Add myuplink reauth flow (#110587)
* WIP test * WIP * WIP Reauth flow. Test fail otherways OK. * Minor adjustments to tests * Merge * Merge * Next level... * Cleanup according to review * It works! * Simplify setup * Remove default * Remove files from PR * Add back test_init * Add back test_sensor * Adjust error message --------- Co-authored-by: Martin Hjelmare <marhje52@gmail.com>pull/110474/head^2
parent
d99a7e2825
commit
f5dad1d312
|
@ -6,6 +6,7 @@ from myuplink import MyUplinkAPI
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||||
from homeassistant.helpers import (
|
from homeassistant.helpers import (
|
||||||
aiohttp_client,
|
aiohttp_client,
|
||||||
config_entry_oauth2_flow,
|
config_entry_oauth2_flow,
|
||||||
|
@ -13,7 +14,7 @@ from homeassistant.helpers import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from .api import AsyncConfigEntryAuth
|
from .api import AsyncConfigEntryAuth
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN, OAUTH2_SCOPES
|
||||||
from .coordinator import MyUplinkDataCoordinator
|
from .coordinator import MyUplinkDataCoordinator
|
||||||
|
|
||||||
PLATFORMS: list[Platform] = [
|
PLATFORMS: list[Platform] = [
|
||||||
|
@ -33,6 +34,10 @@ async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> b
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, implementation)
|
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, implementation)
|
||||||
|
|
||||||
|
if set(config_entry.data["token"]["scope"].split(" ")) != set(OAUTH2_SCOPES):
|
||||||
|
raise ConfigEntryAuthFailed("Incorrect OAuth2 scope")
|
||||||
|
|
||||||
auth = AsyncConfigEntryAuth(aiohttp_client.async_get_clientsession(hass), session)
|
auth = AsyncConfigEntryAuth(aiohttp_client.async_get_clientsession(hass), session)
|
||||||
|
|
||||||
# Setup MyUplinkAPI and coordinator for data fetch
|
# Setup MyUplinkAPI and coordinator for data fetch
|
||||||
|
|
|
@ -1,7 +1,10 @@
|
||||||
"""Config flow for myUplink."""
|
"""Config flow for myUplink."""
|
||||||
|
from collections.abc import Mapping
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
|
||||||
from .const import DOMAIN, OAUTH2_SCOPES
|
from .const import DOMAIN, OAUTH2_SCOPES
|
||||||
|
@ -14,6 +17,8 @@ class OAuth2FlowHandler(
|
||||||
|
|
||||||
DOMAIN = DOMAIN
|
DOMAIN = DOMAIN
|
||||||
|
|
||||||
|
config_entry_reauth: ConfigEntry | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def logger(self) -> logging.Logger:
|
def logger(self) -> logging.Logger:
|
||||||
"""Return logger."""
|
"""Return logger."""
|
||||||
|
@ -23,3 +28,30 @@ class OAuth2FlowHandler(
|
||||||
def extra_authorize_data(self) -> dict[str, Any]:
|
def extra_authorize_data(self) -> dict[str, Any]:
|
||||||
"""Extra data that needs to be appended to the authorize url."""
|
"""Extra data that needs to be appended to the authorize url."""
|
||||||
return {"scope": " ".join(OAUTH2_SCOPES)}
|
return {"scope": " ".join(OAUTH2_SCOPES)}
|
||||||
|
|
||||||
|
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||||
|
"""Perform reauth upon an API authentication error."""
|
||||||
|
self.config_entry_reauth = self.hass.config_entries.async_get_entry(
|
||||||
|
self.context["entry_id"]
|
||||||
|
)
|
||||||
|
return await self.async_step_reauth_confirm()
|
||||||
|
|
||||||
|
async def async_step_reauth_confirm(
|
||||||
|
self, user_input: Mapping[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Dialog that informs the user that reauth is required."""
|
||||||
|
if user_input is None:
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="reauth_confirm",
|
||||||
|
)
|
||||||
|
|
||||||
|
return await self.async_step_user()
|
||||||
|
|
||||||
|
async def async_oauth_create_entry(self, data: dict) -> FlowResult:
|
||||||
|
"""Create or update the config entry."""
|
||||||
|
if self.config_entry_reauth:
|
||||||
|
return self.async_update_reload_and_abort(
|
||||||
|
self.config_entry_reauth,
|
||||||
|
data=data,
|
||||||
|
)
|
||||||
|
return await super().async_oauth_create_entry(data)
|
||||||
|
|
|
@ -5,4 +5,4 @@ DOMAIN = "myuplink"
|
||||||
API_ENDPOINT = "https://api.myuplink.com"
|
API_ENDPOINT = "https://api.myuplink.com"
|
||||||
OAUTH2_AUTHORIZE = "https://api.myuplink.com/oauth/authorize"
|
OAUTH2_AUTHORIZE = "https://api.myuplink.com/oauth/authorize"
|
||||||
OAUTH2_TOKEN = "https://api.myuplink.com/oauth/token"
|
OAUTH2_TOKEN = "https://api.myuplink.com/oauth/token"
|
||||||
OAUTH2_SCOPES = ["READSYSTEM", "offline_access"]
|
OAUTH2_SCOPES = ["WRITESYSTEM", "READSYSTEM", "offline_access"]
|
||||||
|
|
|
@ -3,6 +3,10 @@
|
||||||
"step": {
|
"step": {
|
||||||
"pick_implementation": {
|
"pick_implementation": {
|
||||||
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
|
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
|
||||||
|
},
|
||||||
|
"reauth_confirm": {
|
||||||
|
"title": "[%key:common::config_flow::title::reauth%]",
|
||||||
|
"description": "The myUplink integration needs to re-authenticate your account"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
|
@ -12,7 +16,8 @@
|
||||||
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
|
"missing_configuration": "[%key:common::config_flow::abort::oauth2_missing_configuration%]",
|
||||||
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
|
"authorize_url_timeout": "[%key:common::config_flow::abort::oauth2_authorize_url_timeout%]",
|
||||||
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
|
"no_url_available": "[%key:common::config_flow::abort::oauth2_no_url_available%]",
|
||||||
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]"
|
"user_rejected_authorize": "[%key:common::config_flow::abort::oauth2_user_rejected_authorize%]",
|
||||||
|
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]"
|
||||||
},
|
},
|
||||||
"create_entry": {
|
"create_entry": {
|
||||||
"default": "[%key:common::config_flow::create_entry::authenticated%]"
|
"default": "[%key:common::config_flow::create_entry::authenticated%]"
|
||||||
|
|
|
@ -28,9 +28,9 @@ def mock_expires_at() -> float:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
def mock_config_entry(hass: HomeAssistant, expires_at: float) -> MockConfigEntry:
|
||||||
"""Return the default mocked config entry."""
|
"""Return the default mocked config entry."""
|
||||||
return MockConfigEntry(
|
config_entry = MockConfigEntry(
|
||||||
version=1,
|
version=1,
|
||||||
domain=DOMAIN,
|
domain=DOMAIN,
|
||||||
title="myUplink test",
|
title="myUplink test",
|
||||||
|
@ -38,7 +38,7 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
||||||
"auth_implementation": DOMAIN,
|
"auth_implementation": DOMAIN,
|
||||||
"token": {
|
"token": {
|
||||||
"access_token": "Fake_token",
|
"access_token": "Fake_token",
|
||||||
"scope": "READSYSTEM offline",
|
"scope": "WRITESYSTEM READSYSTEM offline_access",
|
||||||
"expires_in": 86399,
|
"expires_in": 86399,
|
||||||
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
|
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
|
||||||
"token_type": "Bearer",
|
"token_type": "Bearer",
|
||||||
|
@ -47,6 +47,8 @@ def mock_config_entry(expires_at: int) -> MockConfigEntry:
|
||||||
},
|
},
|
||||||
entry_id="myuplink_test",
|
entry_id="myuplink_test",
|
||||||
)
|
)
|
||||||
|
config_entry.add_to_hass(hass)
|
||||||
|
return config_entry
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|
|
@ -7,7 +7,6 @@
|
||||||
'access_token': '**REDACTED**',
|
'access_token': '**REDACTED**',
|
||||||
'expires_in': 86399,
|
'expires_in': 86399,
|
||||||
'refresh_token': '**REDACTED**',
|
'refresh_token': '**REDACTED**',
|
||||||
'scope': 'READSYSTEM offline',
|
|
||||||
'token_type': 'Bearer',
|
'token_type': 'Bearer',
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
|
|
@ -2,35 +2,24 @@
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.application_credentials import (
|
|
||||||
ClientCredential,
|
|
||||||
async_import_client_credential,
|
|
||||||
)
|
|
||||||
from homeassistant.components.myuplink.const import (
|
from homeassistant.components.myuplink.const import (
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
OAUTH2_AUTHORIZE,
|
OAUTH2_AUTHORIZE,
|
||||||
OAUTH2_TOKEN,
|
OAUTH2_TOKEN,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
from homeassistant.setup import async_setup_component
|
|
||||||
|
|
||||||
CLIENT_ID = "1234"
|
from .const import CLIENT_ID
|
||||||
CLIENT_SECRET = "5678"
|
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||||
|
from tests.typing import ClientSessionGenerator
|
||||||
|
|
||||||
@pytest.fixture
|
REDIRECT_URL = "https://example.com/auth/external/callback"
|
||||||
async def setup_credentials(hass: HomeAssistant) -> None:
|
CURRENT_SCOPE = "WRITESYSTEM READSYSTEM offline_access"
|
||||||
"""Fixture to setup credentials."""
|
|
||||||
assert await async_setup_component(hass, "application_credentials", {})
|
|
||||||
await async_import_client_credential(
|
|
||||||
hass,
|
|
||||||
DOMAIN,
|
|
||||||
ClientCredential(CLIENT_ID, CLIENT_SECRET),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_full_flow(
|
async def test_full_flow(
|
||||||
|
@ -42,21 +31,21 @@ async def test_full_flow(
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Check full flow."""
|
"""Check full flow."""
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
"myuplink", context={"source": config_entries.SOURCE_USER}
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
)
|
)
|
||||||
state = config_entry_oauth2_flow._encode_jwt(
|
state = config_entry_oauth2_flow._encode_jwt(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
"flow_id": result["flow_id"],
|
"flow_id": result["flow_id"],
|
||||||
"redirect_uri": "https://example.com/auth/external/callback",
|
"redirect_uri": REDIRECT_URL,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert result["url"] == (
|
assert result["url"] == (
|
||||||
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
||||||
"&redirect_uri=https://example.com/auth/external/callback"
|
f"&redirect_uri={REDIRECT_URL}"
|
||||||
f"&state={state}"
|
f"&state={state}"
|
||||||
"&scope=READSYSTEM+offline_access"
|
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
|
||||||
)
|
)
|
||||||
|
|
||||||
client = await hass_client_no_auth()
|
client = await hass_client_no_auth()
|
||||||
|
@ -75,9 +64,100 @@ async def test_full_flow(
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.myuplink.async_setup_entry", return_value=True
|
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
|
||||||
) as mock_setup:
|
) as mock_setup:
|
||||||
await hass.config_entries.flow.async_configure(result["flow_id"])
|
await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
|
||||||
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||||
assert len(mock_setup.mock_calls) == 1
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_flow_reauth(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client_no_auth: ClientSessionGenerator,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
current_request_with_host: None,
|
||||||
|
setup_credentials: None,
|
||||||
|
mock_config_entry: MockConfigEntry,
|
||||||
|
expires_at: float,
|
||||||
|
) -> None:
|
||||||
|
"""Test reauth step."""
|
||||||
|
|
||||||
|
OLD_SCOPE = "READSYSTEM offline_access"
|
||||||
|
OLD_SCOPE_TOKEN = {
|
||||||
|
"auth_implementation": DOMAIN,
|
||||||
|
"token": {
|
||||||
|
"access_token": "Fake_token",
|
||||||
|
"scope": OLD_SCOPE,
|
||||||
|
"expires_in": 86399,
|
||||||
|
"refresh_token": "3012bc9f-7a65-4240-b817-9154ffdcc30f",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_at": expires_at,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE
|
||||||
|
assert hass.config_entries.async_update_entry(
|
||||||
|
mock_config_entry, data=OLD_SCOPE_TOKEN
|
||||||
|
)
|
||||||
|
assert mock_config_entry.data["token"]["scope"] == OLD_SCOPE
|
||||||
|
|
||||||
|
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={
|
||||||
|
"source": config_entries.SOURCE_REAUTH,
|
||||||
|
"entry_id": mock_config_entry.entry_id,
|
||||||
|
},
|
||||||
|
data=mock_config_entry.data,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["step_id"] == "reauth_confirm"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input={}
|
||||||
|
)
|
||||||
|
assert result["step_id"] == "auth"
|
||||||
|
|
||||||
|
state = config_entry_oauth2_flow._encode_jwt(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
"flow_id": result["flow_id"],
|
||||||
|
"redirect_uri": REDIRECT_URL,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result["url"] == (
|
||||||
|
f"{OAUTH2_AUTHORIZE}?response_type=code&client_id={CLIENT_ID}"
|
||||||
|
f"&redirect_uri={REDIRECT_URL}"
|
||||||
|
f"&state={state}"
|
||||||
|
f"&scope={CURRENT_SCOPE.replace(' ', '+')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
client = await hass_client_no_auth()
|
||||||
|
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||||
|
|
||||||
|
aioclient_mock.post(
|
||||||
|
OAUTH2_TOKEN,
|
||||||
|
json={
|
||||||
|
"refresh_token": "updated-refresh-token",
|
||||||
|
"access_token": "updated-access-token",
|
||||||
|
"type": "Bearer",
|
||||||
|
"expires_in": "60",
|
||||||
|
"scope": CURRENT_SCOPE,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
f"homeassistant.components.{DOMAIN}.async_setup_entry", return_value=True
|
||||||
|
) as mock_setup:
|
||||||
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert result.get("type") == FlowResultType.ABORT
|
||||||
|
assert result.get("reason") == "reauth_successful"
|
||||||
|
|
||||||
|
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
|
||||||
|
assert len(mock_setup.mock_calls) == 1
|
||||||
|
assert mock_config_entry.data["token"]["scope"] == CURRENT_SCOPE
|
||||||
|
|
|
@ -22,5 +22,9 @@ async def test_diagnostics(
|
||||||
assert await get_diagnostics_for_config_entry(
|
assert await get_diagnostics_for_config_entry(
|
||||||
hass, hass_client, mock_config_entry
|
hass, hass_client, mock_config_entry
|
||||||
) == snapshot(
|
) == snapshot(
|
||||||
exclude=paths("config_entry_data.token.expires_at", "myuplink_test.entry_id")
|
exclude=paths(
|
||||||
|
"config_entry_data.token.expires_at",
|
||||||
|
"myuplink_test.entry_id",
|
||||||
|
"config_entry_data.token.scope",
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue