Add reauth support to Schlage (#103351)

* Add reauth support to Schlage

* Enforce same user credentials are used on reauth

* Changes requested during review

* Changes requested during review

* Add password to reauth_confirm data
pull/104063/head
David Knowles 2023-11-16 02:47:13 -05:00 committed by GitHub
parent 1c817cc18c
commit e10c5246b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 228 additions and 31 deletions

View File

@ -7,8 +7,9 @@ import pyschlage
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from .const import DOMAIN, LOGGER from .const import DOMAIN
from .coordinator import SchlageDataUpdateCoordinator from .coordinator import SchlageDataUpdateCoordinator
PLATFORMS: list[Platform] = [ PLATFORMS: list[Platform] = [
@ -26,8 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
try: try:
auth = await hass.async_add_executor_job(pyschlage.Auth, username, password) auth = await hass.async_add_executor_job(pyschlage.Auth, username, password)
except WarrantException as ex: except WarrantException as ex:
LOGGER.error("Schlage authentication failed: %s", ex) raise ConfigEntryAuthFailed from ex
return False
coordinator = SchlageDataUpdateCoordinator(hass, username, pyschlage.Schlage(auth)) coordinator = SchlageDataUpdateCoordinator(hass, username, pyschlage.Schlage(auth))
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator

View File

@ -1,6 +1,7 @@
"""Config flow for Schlage integration.""" """Config flow for Schlage integration."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
from typing import Any from typing import Any
import pyschlage import pyschlage
@ -8,6 +9,7 @@ from pyschlage.exceptions import NotAuthorizedError
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
@ -16,6 +18,7 @@ from .const import DOMAIN, LOGGER
STEP_USER_DATA_SCHEMA = vol.Schema( STEP_USER_DATA_SCHEMA = vol.Schema(
{vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str} {vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str}
) )
STEP_REAUTH_DATA_SCHEMA = vol.Schema({vol.Required(CONF_PASSWORD): str})
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@ -23,36 +26,88 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
reauth_entry: ConfigEntry | None = None
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> FlowResult: ) -> FlowResult:
"""Handle the initial step.""" """Handle the initial step."""
errors: dict[str, str] = {} if user_input is None:
if user_input is not None: return self._show_user_form({})
username = user_input[CONF_USERNAME] username = user_input[CONF_USERNAME]
password = user_input[CONF_PASSWORD] password = user_input[CONF_PASSWORD]
try: user_id, errors = await self.hass.async_add_executor_job(
user_id = await self.hass.async_add_executor_job( _authenticate, username, password
_authenticate, username, password )
) if user_id is None:
except NotAuthorizedError: return self._show_user_form(errors)
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
LOGGER.exception("Unknown error")
errors["base"] = "unknown"
else:
await self.async_set_unique_id(user_id)
return self.async_create_entry(title=username, data=user_input)
await self.async_set_unique_id(user_id)
return self.async_create_entry(title=username, data=user_input)
def _show_user_form(self, errors: dict[str, str]) -> FlowResult:
"""Show the user form."""
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
) )
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle reauth upon an API authentication error."""
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
def _authenticate(username: str, password: str) -> str: async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Dialog that informs the user that reauth is required."""
assert self.reauth_entry is not None
if user_input is None:
return self._show_reauth_form({})
username = self.reauth_entry.data[CONF_USERNAME]
password = user_input[CONF_PASSWORD]
user_id, errors = await self.hass.async_add_executor_job(
_authenticate, username, password
)
if user_id is None:
return self._show_reauth_form(errors)
if self.reauth_entry.unique_id != user_id:
return self.async_abort(reason="wrong_account")
data = {
CONF_USERNAME: username,
CONF_PASSWORD: user_input[CONF_PASSWORD],
}
self.hass.config_entries.async_update_entry(self.reauth_entry, data=data)
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
return self.async_abort(reason="reauth_successful")
def _show_reauth_form(self, errors: dict[str, str]) -> FlowResult:
"""Show the reauth form."""
return self.async_show_form(
step_id="reauth_confirm",
data_schema=STEP_REAUTH_DATA_SCHEMA,
errors=errors,
)
def _authenticate(username: str, password: str) -> tuple[str | None, dict[str, str]]:
"""Authenticate with the Schlage API.""" """Authenticate with the Schlage API."""
auth = pyschlage.Auth(username, password) user_id = None
auth.authenticate() errors: dict[str, str] = {}
# The user_id property will make a blocking call if it's not already try:
# cached. To avoid blocking the event loop, we read it here. auth = pyschlage.Auth(username, password)
return auth.user_id auth.authenticate()
except NotAuthorizedError:
errors["base"] = "invalid_auth"
except Exception: # pylint: disable=broad-except
LOGGER.exception("Unknown error")
errors["base"] = "unknown"
else:
# The user_id property will make a blocking call if it's not already
# cached. To avoid blocking the event loop, we read it here.
user_id = auth.user_id
return user_id, errors

View File

@ -5,10 +5,11 @@ import asyncio
from dataclasses import dataclass from dataclasses import dataclass
from pyschlage import Lock, Schlage from pyschlage import Lock, Schlage
from pyschlage.exceptions import Error as SchlageError from pyschlage.exceptions import Error as SchlageError, NotAuthorizedError
from pyschlage.log import LockLog from pyschlage.log import LockLog
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import DOMAIN, LOGGER, UPDATE_INTERVAL from .const import DOMAIN, LOGGER, UPDATE_INTERVAL
@ -43,6 +44,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]):
"""Fetch the latest data from the Schlage API.""" """Fetch the latest data from the Schlage API."""
try: try:
locks = await self.hass.async_add_executor_job(self.api.locks) locks = await self.hass.async_add_executor_job(self.api.locks)
except NotAuthorizedError as ex:
raise ConfigEntryAuthFailed from ex
except SchlageError as ex: except SchlageError as ex:
raise UpdateFailed("Failed to refresh Schlage data") from ex raise UpdateFailed("Failed to refresh Schlage data") from ex
lock_data = await asyncio.gather( lock_data = await asyncio.gather(
@ -64,6 +67,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]):
logs = previous_lock_data.logs logs = previous_lock_data.logs
try: try:
logs = lock.logs() logs = lock.logs()
except NotAuthorizedError as ex:
raise ConfigEntryAuthFailed from ex
except SchlageError as ex: except SchlageError as ex:
LOGGER.debug('Failed to read logs for lock "%s": %s', lock.name, ex) LOGGER.debug('Failed to read logs for lock "%s": %s', lock.name, ex)

View File

@ -6,6 +6,13 @@
"username": "[%key:common::config_flow::data::username%]", "username": "[%key:common::config_flow::data::username%]",
"password": "[%key:common::config_flow::data::password%]" "password": "[%key:common::config_flow::data::password%]"
} }
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "The Schlage integration needs to re-authenticate your account",
"data": {
"password": "[%key:common::config_flow::data::password%]"
}
} }
}, },
"error": { "error": {
@ -13,7 +20,9 @@
"unknown": "[%key:common::config_flow::error::unknown%]" "unknown": "[%key:common::config_flow::error::unknown%]"
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]" "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"wrong_account": "The user credentials provided do not match this Schlage account."
} }
}, },
"entity": { "entity": {

View File

@ -54,14 +54,14 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
@pytest.fixture @pytest.fixture
def mock_schlage(): def mock_schlage() -> Mock:
"""Mock pyschlage.Schlage.""" """Mock pyschlage.Schlage."""
with patch("pyschlage.Schlage", autospec=True) as mock_schlage: with patch("pyschlage.Schlage", autospec=True) as mock_schlage:
yield mock_schlage.return_value yield mock_schlage.return_value
@pytest.fixture @pytest.fixture
def mock_pyschlage_auth(): def mock_pyschlage_auth() -> Mock:
"""Mock pyschlage.Auth.""" """Mock pyschlage.Auth."""
with patch("pyschlage.Auth", autospec=True) as mock_auth: with patch("pyschlage.Auth", autospec=True) as mock_auth:
mock_auth.return_value.user_id = "abc123" mock_auth.return_value.user_id = "abc123"
@ -69,7 +69,7 @@ def mock_pyschlage_auth():
@pytest.fixture @pytest.fixture
def mock_lock(): def mock_lock() -> Mock:
"""Mock Lock fixture.""" """Mock Lock fixture."""
mock_lock = create_autospec(Lock) mock_lock = create_autospec(Lock)
mock_lock.configure_mock( mock_lock.configure_mock(

View File

@ -9,6 +9,8 @@ from homeassistant.components.schlage.const import DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry
pytestmark = pytest.mark.usefixtures("mock_setup_entry") pytestmark = pytest.mark.usefixtures("mock_setup_entry")
@ -78,3 +80,94 @@ async def test_form_unknown(hass: HomeAssistant, mock_pyschlage_auth: Mock) -> N
assert result2["type"] == FlowResultType.FORM assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {"base": "unknown"} assert result2["errors"] == {"base": "unknown"}
async def test_reauth(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_setup_entry: AsyncMock,
mock_pyschlage_auth: Mock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"password": "new-password"},
)
await hass.async_block_till_done()
mock_pyschlage_auth.authenticate.assert_called_once_with()
assert result2["type"] == FlowResultType.ABORT
assert result2["reason"] == "reauth_successful"
assert mock_added_config_entry.data == {
"username": "asdf@asdf.com",
"password": "new-password",
}
assert len(mock_setup_entry.mock_calls) == 1
async def test_reauth_invalid_auth(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_setup_entry: AsyncMock,
mock_pyschlage_auth: Mock,
) -> None:
"""Test reauth flow."""
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
mock_pyschlage_auth.authenticate.reset_mock()
mock_pyschlage_auth.authenticate.side_effect = NotAuthorizedError
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"password": "new-password"},
)
await hass.async_block_till_done()
mock_pyschlage_auth.authenticate.assert_called_once_with()
assert result2["type"] == FlowResultType.FORM
assert result2["errors"] == {"base": "invalid_auth"}
async def test_reauth_wrong_account(
hass: HomeAssistant,
mock_added_config_entry: MockConfigEntry,
mock_setup_entry: AsyncMock,
mock_pyschlage_auth: Mock,
) -> None:
"""Test reauth flow."""
mock_pyschlage_auth.user_id = "bad-user-id"
mock_added_config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
[result] = flows
assert result["step_id"] == "reauth_confirm"
result2 = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"password": "new-password"},
)
await hass.async_block_till_done()
mock_pyschlage_auth.authenticate.assert_called_once_with()
assert result2["type"] == FlowResultType.ABORT
assert result2["reason"] == "wrong_account"
assert mock_added_config_entry.data == {
"username": "asdf@asdf.com",
"password": "hunter2",
}
assert len(mock_setup_entry.mock_calls) == 1

View File

@ -3,7 +3,7 @@
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from pycognito.exceptions import WarrantException from pycognito.exceptions import WarrantException
from pyschlage.exceptions import Error from pyschlage.exceptions import Error, NotAuthorizedError
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -43,6 +43,41 @@ async def test_update_data_fails(
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
async def test_update_data_auth_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_pyschlage_auth: Mock,
mock_schlage: Mock,
) -> None:
"""Test that we properly handle API errors."""
mock_schlage.locks.side_effect = NotAuthorizedError
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_schlage.locks.call_count == 1
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
async def test_update_data_get_logs_auth_error(
hass: HomeAssistant,
mock_config_entry: MockConfigEntry,
mock_pyschlage_auth: Mock,
mock_schlage: Mock,
mock_lock: Mock,
) -> None:
"""Test that we properly handle API errors."""
mock_schlage.locks.return_value = [mock_lock]
mock_lock.logs.reset_mock()
mock_lock.logs.side_effect = NotAuthorizedError
mock_config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(mock_config_entry.entry_id)
await hass.async_block_till_done()
assert mock_schlage.locks.call_count == 1
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
async def test_load_unload_config_entry( async def test_load_unload_config_entry(
hass: HomeAssistant, hass: HomeAssistant,
mock_config_entry: MockConfigEntry, mock_config_entry: MockConfigEntry,