Add reauth flow to webOS TV integration (#86168)
* Add reauth flow to webOS TV integration * Remove unnecessary elsepull/86178/head
parent
f2b348dbdf
commit
c40c37e9ee
|
@ -18,6 +18,7 @@ from homeassistant.const import (
|
|||
EVENT_HOMEASSISTANT_STOP,
|
||||
)
|
||||
from homeassistant.core import Event, HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed
|
||||
from homeassistant.helpers import config_validation as cv, discovery
|
||||
from homeassistant.helpers.dispatcher import async_dispatcher_send
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
@ -77,8 +78,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
|
||||
# Attempt a connection, but fail gracefully if tv is off for example.
|
||||
client = WebOsClient(host, key)
|
||||
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError):
|
||||
await client.connect()
|
||||
with suppress(*WEBOSTV_EXCEPTIONS):
|
||||
try:
|
||||
await client.connect()
|
||||
except WebOsTvPairError as err:
|
||||
raise ConfigEntryAuthFailed(err) from err
|
||||
|
||||
# If pairing request accepted there will be no error
|
||||
# Update the stored key without triggering reauth
|
||||
update_client_key(hass, entry, client)
|
||||
|
||||
async def async_service_handler(service: ServiceCall) -> None:
|
||||
method = SERVICE_TO_METHOD[service.service]
|
||||
|
@ -141,6 +149,19 @@ async def async_control_connect(host: str, key: str | None) -> WebOsClient:
|
|||
return client
|
||||
|
||||
|
||||
def update_client_key(
|
||||
hass: HomeAssistant, entry: ConfigEntry, client: WebOsClient
|
||||
) -> None:
|
||||
"""Check and update stored client key if key has changed."""
|
||||
host = entry.data[CONF_HOST]
|
||||
key = entry.data[CONF_CLIENT_SECRET]
|
||||
|
||||
if client.client_key != key:
|
||||
_LOGGER.debug("Updating client key for host %s", host)
|
||||
data = {CONF_HOST: host, CONF_CLIENT_SECRET: client.client_key}
|
||||
hass.config_entries.async_update_entry(entry, data=data)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload a config entry."""
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Config flow to configure webostv component."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
@ -8,14 +9,14 @@ from urllib.parse import urlparse
|
|||
from aiowebostv import WebOsTvPairError
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components import ssdp
|
||||
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow
|
||||
from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.data_entry_flow import AbortFlow, FlowResult
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from . import async_control_connect
|
||||
from . import async_control_connect, update_client_key
|
||||
from .const import CONF_SOURCES, DEFAULT_NAME, DOMAIN, WEBOSTV_EXCEPTIONS
|
||||
from .helpers import async_get_sources
|
||||
|
||||
|
@ -30,7 +31,7 @@ DATA_SCHEMA = vol.Schema(
|
|||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
class FlowHandler(ConfigFlow, domain=DOMAIN):
|
||||
"""WebosTV configuration flow."""
|
||||
|
||||
VERSION = 1
|
||||
|
@ -40,12 +41,11 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
self._host: str = ""
|
||||
self._name: str = ""
|
||||
self._uuid: str | None = None
|
||||
self._entry: ConfigEntry | None = None
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
) -> OptionsFlowHandler:
|
||||
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
|
||||
"""Get the options flow for this handler."""
|
||||
return OptionsFlowHandler(config_entry)
|
||||
|
||||
|
@ -78,7 +78,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
)
|
||||
self.hass.config_entries.async_update_entry(entry, unique_id=self._uuid)
|
||||
|
||||
raise data_entry_flow.AbortFlow("already_configured")
|
||||
raise AbortFlow("already_configured")
|
||||
|
||||
async def async_step_pairing(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
|
@ -129,11 +129,37 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
self._uuid = uuid
|
||||
return await self.async_step_pairing()
|
||||
|
||||
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
|
||||
"""Perform reauth upon an WebOsTvPairError."""
|
||||
self._host = entry_data[CONF_HOST]
|
||||
self._entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])
|
||||
return await self.async_step_reauth_confirm()
|
||||
|
||||
class OptionsFlowHandler(config_entries.OptionsFlow):
|
||||
async def async_step_reauth_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Dialog that informs the user that reauth is required."""
|
||||
assert self._entry is not None
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
client = await async_control_connect(self._host, None)
|
||||
except WebOsTvPairError:
|
||||
return self.async_abort(reason="error_pairing")
|
||||
except WEBOSTV_EXCEPTIONS:
|
||||
return self.async_abort(reason="reauth_unsuccessful")
|
||||
|
||||
update_client_key(self.hass, self._entry, client)
|
||||
await self.hass.config_entries.async_reload(self._entry.entry_id)
|
||||
return self.async_abort(reason="reauth_successful")
|
||||
|
||||
return self.async_show_form(step_id="reauth_confirm")
|
||||
|
||||
|
||||
class OptionsFlowHandler(OptionsFlow):
|
||||
"""Handle options."""
|
||||
|
||||
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
|
||||
def __init__(self, config_entry: ConfigEntry) -> None:
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
self.options = config_entry.options
|
||||
|
|
|
@ -39,6 +39,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
from homeassistant.helpers.restore_state import RestoreEntity
|
||||
from homeassistant.helpers.trigger import PluggableAction
|
||||
|
||||
from . import update_client_key
|
||||
from .const import (
|
||||
ATTR_PAYLOAD,
|
||||
ATTR_SOUND_OUTPUT,
|
||||
|
@ -73,18 +74,11 @@ SCAN_INTERVAL = timedelta(seconds=10)
|
|||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
|
||||
) -> None:
|
||||
"""Set up the LG webOS Smart TV platform."""
|
||||
unique_id = config_entry.unique_id
|
||||
assert unique_id
|
||||
name = config_entry.title
|
||||
sources = config_entry.options.get(CONF_SOURCES)
|
||||
client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id]
|
||||
|
||||
async_add_entities([LgWebOSMediaPlayerEntity(client, name, sources, unique_id)])
|
||||
client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id]
|
||||
async_add_entities([LgWebOSMediaPlayerEntity(entry, client)])
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound="LgWebOSMediaPlayerEntity")
|
||||
|
@ -123,19 +117,14 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
|
|||
|
||||
_attr_device_class = MediaPlayerDeviceClass.TV
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: WebOsClient,
|
||||
name: str,
|
||||
sources: list[str] | None,
|
||||
unique_id: str,
|
||||
) -> None:
|
||||
def __init__(self, entry: ConfigEntry, client: WebOsClient) -> None:
|
||||
"""Initialize the webos device."""
|
||||
self._entry = entry
|
||||
self._client = client
|
||||
self._attr_assumed_state = True
|
||||
self._attr_name = name
|
||||
self._attr_unique_id = unique_id
|
||||
self._sources = sources
|
||||
self._attr_name = entry.title
|
||||
self._attr_unique_id = entry.unique_id
|
||||
self._sources = entry.options.get(CONF_SOURCES)
|
||||
|
||||
# Assume that the TV is not paused
|
||||
self._paused = False
|
||||
|
@ -326,7 +315,12 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
|
|||
return
|
||||
|
||||
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError):
|
||||
await self._client.connect()
|
||||
try:
|
||||
await self._client.connect()
|
||||
except WebOsTvPairError:
|
||||
self._entry.async_start_reauth(self.hass)
|
||||
else:
|
||||
update_client_key(self.hass, self._entry, self._client)
|
||||
|
||||
@property
|
||||
def supported_features(self) -> MediaPlayerEntityFeature:
|
||||
|
|
|
@ -13,6 +13,10 @@
|
|||
"pairing": {
|
||||
"title": "webOS TV Pairing",
|
||||
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"title": "webOS TV Pairing",
|
||||
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
|
@ -21,7 +25,9 @@
|
|||
"abort": {
|
||||
"error_pairing": "Connected to LG webOS TV but not paired",
|
||||
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
|
||||
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
|
||||
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
|
|
|
@ -3,7 +3,9 @@
|
|||
"abort": {
|
||||
"already_configured": "Device is already configured",
|
||||
"already_in_progress": "Configuration flow is already in progress",
|
||||
"error_pairing": "Connected to LG webOS TV but not paired"
|
||||
"error_pairing": "Connected to LG webOS TV but not paired",
|
||||
"reauth_successful": "Re-authentication was successful",
|
||||
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
|
||||
},
|
||||
"error": {
|
||||
"cannot_connect": "Failed to connect, please turn on your TV or check ip address"
|
||||
|
@ -14,6 +16,10 @@
|
|||
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
|
||||
"title": "webOS TV Pairing"
|
||||
},
|
||||
"reauth_confirm": {
|
||||
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
|
||||
"title": "webOS TV Pairing"
|
||||
},
|
||||
"user": {
|
||||
"data": {
|
||||
"host": "Host",
|
||||
|
|
|
@ -9,11 +9,11 @@ from homeassistant import config_entries
|
|||
from homeassistant.components import ssdp
|
||||
from homeassistant.components.webostv.const import CONF_SOURCES, DOMAIN, LIVE_TV_APP_ID
|
||||
from homeassistant.config_entries import SOURCE_SSDP
|
||||
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_SOURCE
|
||||
from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME, CONF_SOURCE
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from . import setup_webostv
|
||||
from .const import FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME
|
||||
from .const import CLIENT_KEY, FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME
|
||||
|
||||
MOCK_USER_CONFIG = {
|
||||
CONF_HOST: HOST,
|
||||
|
@ -289,3 +289,64 @@ async def test_form_abort_uuid_configured(hass, client):
|
|||
assert result["type"] == FlowResultType.ABORT
|
||||
assert result["reason"] == "already_configured"
|
||||
assert entry.data[CONF_HOST] == "new_host"
|
||||
|
||||
|
||||
async def test_reauth_successful(hass, client, monkeypatch):
|
||||
"""Test that the reauthorization is successful."""
|
||||
entry = await setup_webostv(hass)
|
||||
assert client
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
|
||||
data=entry.data,
|
||||
)
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
assert entry.data[CONF_CLIENT_SECRET] == CLIENT_KEY
|
||||
|
||||
monkeypatch.setattr(client, "client_key", "new_key")
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input={}
|
||||
)
|
||||
|
||||
assert result["type"] == FlowResultType.ABORT
|
||||
assert result["reason"] == "reauth_successful"
|
||||
assert entry.data[CONF_CLIENT_SECRET] == "new_key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"side_effect,reason",
|
||||
[
|
||||
(WebOsTvPairError, "error_pairing"),
|
||||
(ConnectionRefusedError, "reauth_unsuccessful"),
|
||||
],
|
||||
)
|
||||
async def test_reauth_errors(hass, client, monkeypatch, side_effect, reason):
|
||||
"""Test reauthorization errors."""
|
||||
entry = await setup_webostv(hass)
|
||||
assert client
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
|
||||
data=entry.data,
|
||||
)
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["type"] == FlowResultType.FORM
|
||||
assert result["step_id"] == "reauth_confirm"
|
||||
|
||||
monkeypatch.setattr(client, "connect", Mock(side_effect=side_effect))
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input={}
|
||||
)
|
||||
|
||||
assert result["type"] == FlowResultType.ABORT
|
||||
assert result["reason"] == reason
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
"""The tests for the LG webOS TV platform."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
from aiowebostv import WebOsTvPairError
|
||||
|
||||
from homeassistant.components.webostv.const import DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
|
||||
from homeassistant.const import CONF_CLIENT_SECRET
|
||||
|
||||
from . import setup_webostv
|
||||
|
||||
|
||||
async def test_reauth_setup_entry(hass, client, monkeypatch):
|
||||
"""Test reauth flow triggered by setup entry."""
|
||||
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
|
||||
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
|
||||
entry = await setup_webostv(hass)
|
||||
|
||||
assert entry.state == ConfigEntryState.SETUP_ERROR
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
|
||||
flow = flows[0]
|
||||
assert flow.get("step_id") == "reauth_confirm"
|
||||
assert flow.get("handler") == DOMAIN
|
||||
|
||||
assert "context" in flow
|
||||
assert flow["context"].get("source") == SOURCE_REAUTH
|
||||
assert flow["context"].get("entry_id") == entry.entry_id
|
||||
|
||||
|
||||
async def test_key_update_setup_entry(hass, client, monkeypatch):
|
||||
"""Test key update from setup entry."""
|
||||
monkeypatch.setattr(client, "client_key", "new_key")
|
||||
entry = await setup_webostv(hass)
|
||||
|
||||
assert entry.state == ConfigEntryState.LOADED
|
||||
assert entry.data[CONF_CLIENT_SECRET] == "new_key"
|
|
@ -4,6 +4,7 @@ from datetime import timedelta
|
|||
from http import HTTPStatus
|
||||
from unittest.mock import Mock
|
||||
|
||||
from aiowebostv import WebOsTvPairError
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import automation
|
||||
|
@ -37,6 +38,7 @@ from homeassistant.components.webostv.media_player import (
|
|||
SUPPORT_WEBOSTV,
|
||||
SUPPORT_WEBOSTV_VOLUME,
|
||||
)
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
|
||||
from homeassistant.const import (
|
||||
ATTR_COMMAND,
|
||||
ATTR_DEVICE_CLASS,
|
||||
|
@ -763,3 +765,28 @@ async def test_get_image_https(
|
|||
content = await resp.read()
|
||||
|
||||
assert content == b"https_image"
|
||||
|
||||
|
||||
async def test_reauth_reconnect(hass, client, monkeypatch):
|
||||
"""Test reauth flow triggered by reconnect."""
|
||||
entry = await setup_webostv(hass)
|
||||
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
|
||||
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
|
||||
|
||||
assert entry.state == ConfigEntryState.LOADED
|
||||
|
||||
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=20))
|
||||
await hass.async_block_till_done()
|
||||
|
||||
assert entry.state == ConfigEntryState.LOADED
|
||||
|
||||
flows = hass.config_entries.flow.async_progress()
|
||||
assert len(flows) == 1
|
||||
|
||||
flow = flows[0]
|
||||
assert flow.get("step_id") == "reauth_confirm"
|
||||
assert flow.get("handler") == DOMAIN
|
||||
|
||||
assert "context" in flow
|
||||
assert flow["context"].get("source") == SOURCE_REAUTH
|
||||
assert flow["context"].get("entry_id") == entry.entry_id
|
||||
|
|
Loading…
Reference in New Issue