Add reauth support to Schlage (#103351)
* Add reauth support to Schlage * Enforce same user credentials are used on reauth * Changes requested during review * Changes requested during review * Add password to reauth_confirm datapull/104063/head
parent
1c817cc18c
commit
e10c5246b9
|
@ -7,8 +7,9 @@ import pyschlage
|
|||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
|
||||
from .const import DOMAIN, LOGGER
|
||||
from .const import DOMAIN
|
||||
from .coordinator import SchlageDataUpdateCoordinator
|
||||
|
||||
PLATFORMS: list[Platform] = [
|
||||
|
@ -26,8 +27,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
try:
|
||||
auth = await hass.async_add_executor_job(pyschlage.Auth, username, password)
|
||||
except WarrantException as ex:
|
||||
LOGGER.error("Schlage authentication failed: %s", ex)
|
||||
return False
|
||||
raise ConfigEntryAuthFailed from ex
|
||||
|
||||
coordinator = SchlageDataUpdateCoordinator(hass, username, pyschlage.Schlage(auth))
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Config flow for Schlage integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import pyschlage
|
||||
|
@ -8,6 +9,7 @@ from pyschlage.exceptions import NotAuthorizedError
|
|||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_PASSWORD, CONF_USERNAME
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
|
@ -16,6 +18,7 @@ from .const import DOMAIN, LOGGER
|
|||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{vol.Required(CONF_USERNAME): str, vol.Required(CONF_PASSWORD): str}
|
||||
)
|
||||
STEP_REAUTH_DATA_SCHEMA = vol.Schema({vol.Required(CONF_PASSWORD): str})
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
|
@ -23,36 +26,88 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
VERSION = 1
|
||||
|
||||
reauth_entry: ConfigEntry | None = None
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step."""
|
||||
errors: dict[str, str] = {}
|
||||
if user_input is not None:
|
||||
username = user_input[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
try:
|
||||
user_id = await self.hass.async_add_executor_job(
|
||||
_authenticate, username, password
|
||||
)
|
||||
except NotAuthorizedError:
|
||||
errors["base"] = "invalid_auth"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
LOGGER.exception("Unknown error")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
await self.async_set_unique_id(user_id)
|
||||
return self.async_create_entry(title=username, data=user_input)
|
||||
if user_input is None:
|
||||
return self._show_user_form({})
|
||||
username = user_input[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
user_id, errors = await self.hass.async_add_executor_job(
|
||||
_authenticate, username, password
|
||||
)
|
||||
if user_id is None:
|
||||
return self._show_user_form(errors)
|
||||
|
||||
await self.async_set_unique_id(user_id)
|
||||
return self.async_create_entry(title=username, data=user_input)
|
||||
|
||||
def _show_user_form(self, errors: dict[str, str]) -> FlowResult:
|
||||
"""Show the user form."""
|
||||
return self.async_show_form(
|
||||
step_id="user", data_schema=STEP_USER_DATA_SCHEMA, errors=errors
|
||||
)
|
||||
|
||||
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||
"""Handle reauth upon an API authentication error."""
|
||||
self.reauth_entry = self.hass.config_entries.async_get_entry(
|
||||
self.context["entry_id"]
|
||||
)
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
def _authenticate(username: str, password: str) -> str:
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Dialog that informs the user that reauth is required."""
|
||||
assert self.reauth_entry is not None
|
||||
if user_input is None:
|
||||
return self._show_reauth_form({})
|
||||
|
||||
username = self.reauth_entry.data[CONF_USERNAME]
|
||||
password = user_input[CONF_PASSWORD]
|
||||
user_id, errors = await self.hass.async_add_executor_job(
|
||||
_authenticate, username, password
|
||||
)
|
||||
if user_id is None:
|
||||
return self._show_reauth_form(errors)
|
||||
|
||||
if self.reauth_entry.unique_id != user_id:
|
||||
return self.async_abort(reason="wrong_account")
|
||||
|
||||
data = {
|
||||
CONF_USERNAME: username,
|
||||
CONF_PASSWORD: user_input[CONF_PASSWORD],
|
||||
}
|
||||
self.hass.config_entries.async_update_entry(self.reauth_entry, data=data)
|
||||
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
def _show_reauth_form(self, errors: dict[str, str]) -> FlowResult:
|
||||
"""Show the reauth form."""
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
data_schema=STEP_REAUTH_DATA_SCHEMA,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
|
||||
def _authenticate(username: str, password: str) -> tuple[str | None, dict[str, str]]:
|
||||
"""Authenticate with the Schlage API."""
|
||||
auth = pyschlage.Auth(username, password)
|
||||
auth.authenticate()
|
||||
# The user_id property will make a blocking call if it's not already
|
||||
# cached. To avoid blocking the event loop, we read it here.
|
||||
return auth.user_id
|
||||
user_id = None
|
||||
errors: dict[str, str] = {}
|
||||
try:
|
||||
auth = pyschlage.Auth(username, password)
|
||||
auth.authenticate()
|
||||
except NotAuthorizedError:
|
||||
errors["base"] = "invalid_auth"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
LOGGER.exception("Unknown error")
|
||||
errors["base"] = "unknown"
|
||||
else:
|
||||
# The user_id property will make a blocking call if it's not already
|
||||
# cached. To avoid blocking the event loop, we read it here.
|
||||
user_id = auth.user_id
|
||||
return user_id, errors
|
||||
|
|
|
@ -5,10 +5,11 @@ import asyncio
|
|||
from dataclasses import dataclass
|
||||
|
||||
from pyschlage import Lock, Schlage
|
||||
from pyschlage.exceptions import Error as SchlageError
|
||||
from pyschlage.exceptions import Error as SchlageError, NotAuthorizedError
|
||||
from pyschlage.log import LockLog
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
|
||||
|
||||
from .const import DOMAIN, LOGGER, UPDATE_INTERVAL
|
||||
|
@ -43,6 +44,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]):
|
|||
"""Fetch the latest data from the Schlage API."""
|
||||
try:
|
||||
locks = await self.hass.async_add_executor_job(self.api.locks)
|
||||
except NotAuthorizedError as ex:
|
||||
raise ConfigEntryAuthFailed from ex
|
||||
except SchlageError as ex:
|
||||
raise UpdateFailed("Failed to refresh Schlage data") from ex
|
||||
lock_data = await asyncio.gather(
|
||||
|
@ -64,6 +67,8 @@ class SchlageDataUpdateCoordinator(DataUpdateCoordinator[SchlageData]):
|
|||
logs = previous_lock_data.logs
|
||||
try:
|
||||
logs = lock.logs()
|
||||
except NotAuthorizedError as ex:
|
||||
raise ConfigEntryAuthFailed from ex
|
||||
except SchlageError as ex:
|
||||
LOGGER.debug('Failed to read logs for lock "%s": %s', lock.name, ex)
|
||||
|
||||
|
|
|
@ -6,6 +6,13 @@
|
|||
"username": "[%key:common::config_flow::data::username%]",
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "The Schlage integration needs to re-authenticate your account",
|
||||
"data": {
|
||||
"password": "[%key:common::config_flow::data::password%]"
|
||||
}
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
|
@ -13,7 +20,9 @@
|
|||
"unknown": "[%key:common::config_flow::error::unknown%]"
|
||||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
|
||||
"wrong_account": "The user credentials provided do not match this Schlage account."
|
||||
}
|
||||
},
|
||||
"entity": {
|
||||
|
|
|
@ -54,14 +54,14 @@ def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_schlage():
|
||||
def mock_schlage() -> Mock:
|
||||
"""Mock pyschlage.Schlage."""
|
||||
with patch("pyschlage.Schlage", autospec=True) as mock_schlage:
|
||||
yield mock_schlage.return_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pyschlage_auth():
|
||||
def mock_pyschlage_auth() -> Mock:
|
||||
"""Mock pyschlage.Auth."""
|
||||
with patch("pyschlage.Auth", autospec=True) as mock_auth:
|
||||
mock_auth.return_value.user_id = "abc123"
|
||||
|
@ -69,7 +69,7 @@ def mock_pyschlage_auth():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_lock():
|
||||
def mock_lock() -> Mock:
|
||||
"""Mock Lock fixture."""
|
||||
mock_lock = create_autospec(Lock)
|
||||
mock_lock.configure_mock(
|
||||
|
|
|
@ -9,6 +9,8 @@ from homeassistant.components.schlage.const import DOMAIN
|
|||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||
|
||||
|
||||
|
@ -78,3 +80,94 @@ async def test_form_unknown(hass: HomeAssistant, mock_pyschlage_auth: Mock) -> N
|
|||
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": "unknown"}
|
||||
|
||||
|
||||
async def test_reauth(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_pyschlage_auth: Mock,
|
||||
) -> None:
|
||||
"""Test reauth flow."""
|
||||
mock_added_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"password": "new-password"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_pyschlage_auth.authenticate.assert_called_once_with()
|
||||
assert result2["type"] == FlowResultType.ABORT
|
||||
assert result2["reason"] == "reauth_successful"
|
||||
assert mock_added_config_entry.data == {
|
||||
"username": "asdf@asdf.com",
|
||||
"password": "new-password",
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_reauth_invalid_auth(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_pyschlage_auth: Mock,
|
||||
) -> None:
|
||||
"""Test reauth flow."""
|
||||
mock_added_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
mock_pyschlage_auth.authenticate.reset_mock()
|
||||
mock_pyschlage_auth.authenticate.side_effect = NotAuthorizedError
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"password": "new-password"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_pyschlage_auth.authenticate.assert_called_once_with()
|
||||
assert result2["type"] == FlowResultType.FORM
|
||||
assert result2["errors"] == {"base": "invalid_auth"}
|
||||
|
||||
|
||||
async def test_reauth_wrong_account(
|
||||
hass: HomeAssistant,
|
||||
mock_added_config_entry: MockConfigEntry,
|
||||
mock_setup_entry: AsyncMock,
|
||||
mock_pyschlage_auth: Mock,
|
||||
) -> None:
|
||||
"""Test reauth flow."""
|
||||
mock_pyschlage_auth.user_id = "bad-user-id"
|
||||
mock_added_config_entry.async_start_reauth(hass)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
[result] = flows
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
{"password": "new-password"},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
mock_pyschlage_auth.authenticate.assert_called_once_with()
|
||||
assert result2["type"] == FlowResultType.ABORT
|
||||
assert result2["reason"] == "wrong_account"
|
||||
assert mock_added_config_entry.data == {
|
||||
"username": "asdf@asdf.com",
|
||||
"password": "hunter2",
|
||||
}
|
||||
assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
from unittest.mock import Mock, patch
|
||||
|
||||
from pycognito.exceptions import WarrantException
|
||||
from pyschlage.exceptions import Error
|
||||
from pyschlage.exceptions import Error, NotAuthorizedError
|
||||
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
@ -43,6 +43,41 @@ async def test_update_data_fails(
|
|||
assert mock_config_entry.state is ConfigEntryState.SETUP_RETRY
|
||||
|
||||
|
||||
async def test_update_data_auth_error(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_pyschlage_auth: Mock,
|
||||
mock_schlage: Mock,
|
||||
) -> None:
|
||||
"""Test that we properly handle API errors."""
|
||||
mock_schlage.locks.side_effect = NotAuthorizedError
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_schlage.locks.call_count == 1
|
||||
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
|
||||
|
||||
|
||||
async def test_update_data_get_logs_auth_error(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
mock_pyschlage_auth: Mock,
|
||||
mock_schlage: Mock,
|
||||
mock_lock: Mock,
|
||||
) -> None:
|
||||
"""Test that we properly handle API errors."""
|
||||
mock_schlage.locks.return_value = [mock_lock]
|
||||
mock_lock.logs.reset_mock()
|
||||
mock_lock.logs.side_effect = NotAuthorizedError
|
||||
mock_config_entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(mock_config_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert mock_schlage.locks.call_count == 1
|
||||
assert mock_config_entry.state is ConfigEntryState.SETUP_ERROR
|
||||
|
||||
|
||||
async def test_load_unload_config_entry(
|
||||
hass: HomeAssistant,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
|
|
Loading…
Reference in New Issue