Add myuplink reauth flow (#110587)

* WIP test

* WIP

* WIP Reauth flow. Test fail otherways OK.

* Minor adjustments to tests

* Merge

* Merge

* Next level...

* Cleanup according to review

* It works!

* Simplify setup

* Remove default

* Remove files from PR

* Add back test_init

* Add back test_sensor

* Adjust error message

---------

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
pull/110474/head^2
Åke Strandberg 2024-02-17 10:18:53 +01:00 committed by GitHub
parent d99a7e2825
commit f5dad1d312
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 158 additions and 31 deletions

View File

@ -6,6 +6,7 @@ from myuplink import MyUplinkAPI
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import (
aiohttp_client,
config_entry_oauth2_flow,
@ -13,7 +14,7 @@ from homeassistant.helpers import (
)
from .api import AsyncConfigEntryAuth
from .const import DOMAIN
from .const import DOMAIN, OAUTH2_SCOPES
from .coordinator import MyUplinkDataCoordinator
PLATFORMS: list[Platform] = [
@ -33,6 +34,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
)
)
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, implementation)
if set(config_entry.data["token"]["scope"].split(" ")) != set(OAUTH2_SCOPES):
raise ConfigEntryAuthFailed("Incorrect OAuth2 scope")
auth = AsyncConfigEntryAuth(aiohttp_client.async_get_clientsession(hass), session)
# Setup MyUplinkAPI and coordinator for data fetch

View File

@ -1,7 +1,10 @@
"""Config flow for myUplink."""
from collections.abc import Mapping
import logging
from typing import Any
from homeassistant.config_entries import ConfigEntry
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow
from .const import DOMAIN, OAUTH2_SCOPES
@ -14,6 +17,8 @@ class OAuth2FlowHandler(
DOMAIN = DOMAIN
config_entry_reauth: ConfigEntry | None = None
@property
def logger(self) -> logging.Logger:
"""Return logger."""
@ -23,3 +28,30 @@ class OAuth2FlowHandler(
def extra_authorize_data(self) -> dict[str, Any]:
"""Extra data that needs to be appended to the authorize url."""
return {"scope": " ".join(OAUTH2_SCOPES)}
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an API authentication error."""
self.config_entry_reauth = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: Mapping[str, Any] | None = None
) -> FlowResult:
"""Dialog that informs the user that reauth is required."""
if user_input is None:
return self.async_show_form(
step_id="reauth_confirm",
)
return await self.async_step_user()
async def async_oauth_create_entry(self, data: dict) -> FlowResult:
"""Create or update the config entry."""
if self.config_entry_reauth:
return self.async_update_reload_and_abort(
self.config_entry_reauth,
data=data,
)
return await super().async_oauth_create_entry(data)

View File

@ -5,4 +5,4 @@ DOMAIN = "myuplink"
API_ENDPOINT = "https://api.myuplink.com"
OAUTH2_AUTHORIZE = "https://api.myuplink.com/oauth/authorize"
OAUTH2_TOKEN = "https://api.myuplink.com/oauth/token"
OAUTH2_SCOPES = ["READSYSTEM", "offline_access"]
OAUTH2_SCOPES = ["WRITESYSTEM", "READSYSTEM", "offline_access"]

View File

@ -3,6 +3,10 @@
"step": {
"pick_implementation": {
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
},
"reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]",
"description": "The myUplink integration needs to re-authenticate your account"
}
},
"abort": {
@ -12,7 +16,8 @@
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]"
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
},
"create_entry": {
"default": "[%key:common::config_flow::create_entry::authenticated%]"

View File

@ -28,9 +28,9 @@ def mock_expires_at() -> float:
@pytest.fixture
def mock_config_entry(expires_at: int) -> MockConfigEntry:
def mock_config_entry(hass: HomeAssistant, expires_at: float) -> MockConfigEntry:
"""Return the default mocked config entry."""
return MockConfigEntry(
config_entry = MockConfigEntry(
version=1,
domain=DOMAIN,
title="myUplink test",
@ -38,7 +38,7 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
"auth_implementation": DOMAIN,
"token": {
"access_token": "Fake_token",
"scope": "READSYSTEM offline",
"scope": "WRITESYSTEM READSYSTEM offline_access",
"expires_in": 86399,
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
"token_type": "Bearer",
@ -47,6 +47,8 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
},
entry_id="myuplink_test",
)
config_entry.add_to_hass(hass)
return config_entry
@pytest.fixture(autouse=True)

View File

@ -7,7 +7,6 @@
'access_token': '**REDACTED**',
'expires_in': 86399,
'refresh_token': '**REDACTED**',
'scope': 'READSYSTEM offline',
'token_type': 'Bearer',
}),
}),

View File

@ -2,35 +2,24 @@
from unittest.mock import patch
import pytest
from homeassistant import config_entries
from homeassistant.components.application_credentials import (
ClientCredential,
async_import_client_credential,
)
from homeassistant.components.myuplink.const import (
DOMAIN,
OAUTH2_AUTHORIZE,
OAUTH2_TOKEN,
)
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.setup import async_setup_component
CLIENT_ID = "1234"
CLIENT_SECRET = "5678"
from .const import CLIENT_ID
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator
@pytest.fixture
async def setup_credentials(hass: HomeAssistant) -> None:
"""Fixture to setup credentials."""
assert await async_setup_component(hass, "application_credentials", {})
await async_import_client_credential(
hass,
DOMAIN,
ClientCredential(CLIENT_ID, CLIENT_SECRET),
)
REDIRECT_URL = "https://example.com/auth/external/callback"
CURRENT_SCOPE = "WRITESYSTEM READSYSTEM offline_access"
async def test_full_flow(
@ -42,21 +31,21 @@ async def test_full_flow(
) -> None:
"""Check full flow."""
result = await hass.config_entries.flow.async_init(
"myuplink", context={"source": config_entries.SOURCE_USER}
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
"redirect_uri": REDIRECT_URL,
},
)
assert result["url"] == (
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback"
f"&redirect_uri={REDIRECT_URL}"
f"&state={state}"
"&scope=READSYSTEM+offline_access"
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
)
client = await hass_client_no_auth()
@ -75,9 +64,100 @@ async def test_full_flow(
)
with patch(
"homeassistant.components.myuplink.async_setup_entry", return_value=True
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
) as mock_setup:
await hass.config_entries.flow.async_configure(result["flow_id"])
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup.mock_calls) == 1
async def test_flow_reauth(
hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator,
aioclient_mock: AiohttpClientMocker,
current_request_with_host: None,
setup_credentials: None,
mock_config_entry: MockConfigEntry,
expires_at: float,
) -> None:
"""Test reauth step."""
OLD_SCOPE = "READSYSTEM offline_access"
OLD_SCOPE_TOKEN = {
"auth_implementation": DOMAIN,
"token": {
"access_token": "Fake_token",
"scope": OLD_SCOPE,
"expires_in": 86399,
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
"token_type": "Bearer",
"expires_at": expires_at,
},
}
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE
assert hass.config_entries.async_update_entry(
mock_config_entry, data=OLD_SCOPE_TOKEN
)
assert mock_config_entry.data["token"]["scope"] == OLD_SCOPE
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": config_entries.SOURCE_REAUTH,
"entry_id": mock_config_entry.entry_id,
},
data=mock_config_entry.data,
)
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
assert result["step_id"] == "auth"
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": REDIRECT_URL,
},
)
assert result["url"] == (
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
f"&redirect_uri={REDIRECT_URL}"
f"&state={state}"
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
)
client = await hass_client_no_auth()
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
aioclient_mock.post(
OAUTH2_TOKEN,
json={
"refresh_token": "updated-refresh-token",
"access_token": "updated-access-token",
"type": "Bearer",
"expires_in": "60",
"scope": CURRENT_SCOPE,
},
)
with patch(
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
) as mock_setup:
result = await hass.config_entries.flow.async_configure(result["flow_id"])
await hass.async_block_till_done()
assert result.get("type") == FlowResultType.ABORT
assert result.get("reason") == "reauth_successful"
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(mock_setup.mock_calls) == 1
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE

View File

@ -22,5 +22,9 @@ async def test_diagnostics(
assert await get_diagnostics_for_config_entry(
hass, hass_client, mock_config_entry
) == snapshot(
exclude=paths("config_entry_data.token.expires_at", "myuplink_test.entry_id")
exclude=paths(
"config_entry_data.token.expires_at",
"myuplink_test.entry_id",
"config_entry_data.token.scope",
)
)