Add myuplink reauth flow (#110587)
* WIP test * WIP * WIP Reauth flow. Test fail otherways OK. * Minor adjustments to tests * Merge * Merge * Next level... * Cleanup according to review * It works! * Simplify setup * Remove default * Remove files from PR * Add back test_init * Add back test_sensor * Adjust error message --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/110474/head^2
parent
d99a7e2825
commit
f5dad1d312
|
@ -6,6 +6,7 @@ from myuplink import MyUplinkAPI
|
|||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
from homeassistant.helpers import (
|
||||
aiohttp_client,
|
||||
config_entry_oauth2_flow,
|
||||
|
@ -13,7 +14,7 @@ from homeassistant.helpers import (
|
|||
)
|
||||
|
||||
from .api import AsyncConfigEntryAuth
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, OAUTH2_SCOPES
|
||||
from .coordinator import MyUplinkDataCoordinator
|
||||
|
||||
PLATFORMS: list[Platform] = [
|
||||
|
@ -33,6 +34,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
|
|||
)
|
||||
)
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, implementation)
|
||||
|
||||
if set(config_entry.data["token"]["scope"].split(" ")) != set(OAUTH2_SCOPES):
|
||||
raise ConfigEntryAuthFailed("Incorrect OAuth2 scope")
|
||||
|
||||
auth = AsyncConfigEntryAuth(aiohttp_client.async_get_clientsession(hass), session)
|
||||
|
||||
# Setup MyUplinkAPI and coordinator for data fetch
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
"""Config flow for myUplink."""
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from .const import DOMAIN, OAUTH2_SCOPES
|
||||
|
@ -14,6 +17,8 @@ class OAuth2FlowHandler(
|
|||
|
||||
DOMAIN = DOMAIN
|
||||
|
||||
config_entry_reauth: ConfigEntry | None = None
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
|
@ -23,3 +28,30 @@ class OAuth2FlowHandler(
|
|||
def extra_authorize_data(self) -> dict[str, Any]:
|
||||
"""Extra data that needs to be appended to the authorize url."""
|
||||
return {"scope": " ".join(OAUTH2_SCOPES)}
|
||||
|
||||
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||
"""Perform reauth upon an API authentication error."""
|
||||
self.config_entry_reauth = self.hass.config_entries.async_get_entry(
|
||||
self.context["entry_id"]
|
||||
)
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: Mapping[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Dialog that informs the user that reauth is required."""
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="reauth_confirm",
|
||||
)
|
||||
|
||||
return await self.async_step_user()
|
||||
|
||||
async def async_oauth_create_entry(self, data: dict) -> FlowResult:
|
||||
"""Create or update the config entry."""
|
||||
if self.config_entry_reauth:
|
||||
return self.async_update_reload_and_abort(
|
||||
self.config_entry_reauth,
|
||||
data=data,
|
||||
)
|
||||
return await super().async_oauth_create_entry(data)
|
||||
|
|
|
@ -5,4 +5,4 @@ DOMAIN = "myuplink"
|
|||
API_ENDPOINT = "https://api.myuplink.com"
|
||||
OAUTH2_AUTHORIZE = "https://api.myuplink.com/oauth/authorize"
|
||||
OAUTH2_TOKEN = "https://api.myuplink.com/oauth/token"
|
||||
OAUTH2_SCOPES = ["READSYSTEM", "offline_access"]
|
||||
OAUTH2_SCOPES = ["WRITESYSTEM", "READSYSTEM", "offline_access"]
|
||||
|
|
|
@ -3,6 +3,10 @@
|
|||
"step": {
|
||||
"pick_implementation": {
|
||||
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "[%key:common::config_flow::title::reauth%]",
|
||||
"description": "The myUplink integration needs to re-authenticate your account"
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
|
@ -12,7 +16,8 @@
|
|||
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
|
||||
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
|
||||
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
|
||||
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]"
|
||||
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
|
||||
},
|
||||
"create_entry": {
|
||||
"default": "[%key:common::config_flow::create_entry::authenticated%]"
|
||||
|
|
|
@ -28,9 +28,9 @@ def mock_expires_at() -> float:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
||||
def mock_config_entry(hass: HomeAssistant, expires_at: float) -> MockConfigEntry:
|
||||
"""Return the default mocked config entry."""
|
||||
return MockConfigEntry(
|
||||
config_entry = MockConfigEntry(
|
||||
version=1,
|
||||
domain=DOMAIN,
|
||||
title="myUplink test",
|
||||
|
@ -38,7 +38,7 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
|||
"auth_implementation": DOMAIN,
|
||||
"token": {
|
||||
"access_token": "Fake_token",
|
||||
"scope": "READSYSTEM offline",
|
||||
"scope": "WRITESYSTEM READSYSTEM offline_access",
|
||||
"expires_in": 86399,
|
||||
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
|
||||
"token_type": "Bearer",
|
||||
|
@ -47,6 +47,8 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
|||
},
|
||||
entry_id="myuplink_test",
|
||||
)
|
||||
config_entry.add_to_hass(hass)
|
||||
return config_entry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
'access_token': '**REDACTED**',
|
||||
'expires_in': 86399,
|
||||
'refresh_token': '**REDACTED**',
|
||||
'scope': 'READSYSTEM offline',
|
||||
'token_type': 'Bearer',
|
||||
}),
|
||||
}),
|
||||
|
|
|
@ -2,35 +2,24 @@
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.application_credentials import (
|
||||
ClientCredential,
|
||||
async_import_client_credential,
|
||||
)
|
||||
from homeassistant.components.myuplink.const import (
|
||||
DOMAIN,
|
||||
OAUTH2_AUTHORIZE,
|
||||
OAUTH2_TOKEN,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
CLIENT_ID = "1234"
|
||||
CLIENT_SECRET = "5678"
|
||||
from .const import CLIENT_ID
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator
|
||||
|
||||
@pytest.fixture
|
||||
async def setup_credentials(hass: HomeAssistant) -> None:
|
||||
"""Fixture to setup credentials."""
|
||||
assert await async_setup_component(hass, "application_credentials", {})
|
||||
await async_import_client_credential(
|
||||
hass,
|
||||
DOMAIN,
|
||||
ClientCredential(CLIENT_ID, CLIENT_SECRET),
|
||||
)
|
||||
REDIRECT_URL = "https://example.com/auth/external/callback"
|
||||
CURRENT_SCOPE = "WRITESYSTEM READSYSTEM offline_access"
|
||||
|
||||
|
||||
async def test_full_flow(
|
||||
|
@ -42,21 +31,21 @@ async def test_full_flow(
|
|||
) -> None:
|
||||
"""Check full flow."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"myuplink", context={"source": config_entries.SOURCE_USER}
|
||||
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
"redirect_uri": REDIRECT_URL,
|
||||
},
|
||||
)
|
||||
|
||||
assert result["url"] == (
|
||||
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&redirect_uri={REDIRECT_URL}"
|
||||
f"&state={state}"
|
||||
"&scope=READSYSTEM+offline_access"
|
||||
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
|
@ -75,9 +64,100 @@ async def test_full_flow(
|
|||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.myuplink.async_setup_entry", return_value=True
|
||||
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_flow_reauth(
|
||||
hass: HomeAssistant,
|
||||
hass_client_no_auth: ClientSessionGenerator,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
current_request_with_host: None,
|
||||
setup_credentials: None,
|
||||
mock_config_entry: MockConfigEntry,
|
||||
expires_at: float,
|
||||
) -> None:
|
||||
"""Test reauth step."""
|
||||
|
||||
OLD_SCOPE = "READSYSTEM offline_access"
|
||||
OLD_SCOPE_TOKEN = {
|
||||
"auth_implementation": DOMAIN,
|
||||
"token": {
|
||||
"access_token": "Fake_token",
|
||||
"scope": OLD_SCOPE,
|
||||
"expires_in": 86399,
|
||||
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
|
||||
"token_type": "Bearer",
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
}
|
||||
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE
|
||||
assert hass.config_entries.async_update_entry(
|
||||
mock_config_entry, data=OLD_SCOPE_TOKEN
|
||||
)
|
||||
assert mock_config_entry.data["token"]["scope"] == OLD_SCOPE
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={
|
||||
"source": config_entries.SOURCE_REAUTH,
|
||||
"entry_id": mock_config_entry.entry_id,
|
||||
},
|
||||
data=mock_config_entry.data,
|
||||
)
|
||||
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input={}
|
||||
)
|
||||
assert result["step_id"] == "auth"
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": REDIRECT_URL,
|
||||
},
|
||||
)
|
||||
assert result["url"] == (
|
||||
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
||||
f"&redirect_uri={REDIRECT_URL}"
|
||||
f"&state={state}"
|
||||
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
|
||||
)
|
||||
|
||||
client = await hass_client_no_auth()
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
aioclient_mock.post(
|
||||
OAUTH2_TOKEN,
|
||||
json={
|
||||
"refresh_token": "updated-refresh-token",
|
||||
"access_token": "updated-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": "60",
|
||||
"scope": CURRENT_SCOPE,
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
|
||||
) as mock_setup:
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert result.get("type") == FlowResultType.ABORT
|
||||
assert result.get("reason") == "reauth_successful"
|
||||
|
||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||
assert len(mock_setup.mock_calls) == 1
|
||||
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE
|
||||
|
|
|
@ -22,5 +22,9 @@ async def test_diagnostics(
|
|||
assert await get_diagnostics_for_config_entry(
|
||||
hass, hass_client, mock_config_entry
|
||||
) == snapshot(
|
||||
exclude=paths("config_entry_data.token.expires_at", "myuplink_test.entry_id")
|
||||
exclude=paths(
|
||||
"config_entry_data.token.expires_at",
|
||||
"myuplink_test.entry_id",
|
||||
"config_entry_data.token.scope",
|
||||
)
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue