Add reauth flow to webOS TV integration (#86168)

* Add reauth flow to webOS TV integration

* Remove unnecessary else
pull/86178/head
Shay Levy 2023-01-18 18:48:38 +02:00 committed by GitHub
parent f2b348dbdf
commit c40c37e9ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 217 additions and 37 deletions

View File

@ -18,6 +18,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP,
)
from homeassistant.core import Event, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType
@ -77,8 +78,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Attempt a connection, but fail gracefully if tv is off for example.
client = WebOsClient(host, key)
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError):
await client.connect()
with suppress(*WEBOSTV_EXCEPTIONS):
try:
await client.connect()
except WebOsTvPairError as err:
raise ConfigEntryAuthFailed(err) from err
# If pairing request accepted there will be no error
# Update the stored key without triggering reauth
update_client_key(hass, entry, client)
async def async_service_handler(service: ServiceCall) -> None:
method = SERVICE_TO_METHOD[service.service]
@ -141,6 +149,19 @@ async def async_control_connect(host: str, key: str | None) -> WebOsClient:
return client
def update_client_key(
hass: HomeAssistant, entry: ConfigEntry, client: WebOsClient
) -> None:
"""Check and update stored client key if key has changed."""
host = entry.data[CONF_HOST]
key = entry.data[CONF_CLIENT_SECRET]
if client.client_key != key:
_LOGGER.debug("Updating client key for host %s", host)
data = {CONF_HOST: host, CONF_CLIENT_SECRET: client.client_key}
hass.config_entries.async_update_entry(entry, data=data)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@ -1,6 +1,7 @@
"""Config flow to configure webostv component."""
from __future__ import annotations
from collections.abc import Mapping
import logging
from typing import Any
from urllib.parse import urlparse
@ -8,14 +9,14 @@ from urllib.parse import urlparse
from aiowebostv import WebOsTvPairError
import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.components import ssdp
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow
from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import config_validation as cv
from . import async_control_connect
from . import async_control_connect, update_client_key
from .const import CONF_SOURCES, DEFAULT_NAME, DOMAIN, WEBOSTV_EXCEPTIONS
from .helpers import async_get_sources
@ -30,7 +31,7 @@ DATA_SCHEMA = vol.Schema(
_LOGGER = logging.getLogger(__name__)
class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
class FlowHandler(ConfigFlow, domain=DOMAIN):
"""WebosTV configuration flow."""
VERSION = 1
@ -40,12 +41,11 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._host: str = ""
self._name: str = ""
self._uuid: str | None = None
self._entry: ConfigEntry | None = None
@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> OptionsFlowHandler:
def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
"""Get the options flow for this handler."""
return OptionsFlowHandler(config_entry)
@ -78,7 +78,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
)
self.hass.config_entries.async_update_entry(entry, unique_id=self._uuid)
raise data_entry_flow.AbortFlow("already_configured")
raise AbortFlow("already_configured")
async def async_step_pairing(
self, user_input: dict[str, Any] | None = None
@ -129,11 +129,37 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._uuid = uuid
return await self.async_step_pairing()
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an WebOsTvPairError."""
self._host = entry_data[CONF_HOST]
self._entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])
return await self.async_step_reauth_confirm()
class OptionsFlowHandler(config_entries.OptionsFlow):
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Dialog that informs the user that reauth is required."""
assert self._entry is not None
if user_input is not None:
try:
client = await async_control_connect(self._host, None)
except WebOsTvPairError:
return self.async_abort(reason="error_pairing")
except WEBOSTV_EXCEPTIONS:
return self.async_abort(reason="reauth_unsuccessful")
update_client_key(self.hass, self._entry, client)
await self.hass.config_entries.async_reload(self._entry.entry_id)
return self.async_abort(reason="reauth_successful")
return self.async_show_form(step_id="reauth_confirm")
class OptionsFlowHandler(OptionsFlow):
"""Handle options."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None:
def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow."""
self.config_entry = config_entry
self.options = config_entry.options

View File

@ -39,6 +39,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.trigger import PluggableAction
from . import update_client_key
from .const import (
ATTR_PAYLOAD,
ATTR_SOUND_OUTPUT,
@ -73,18 +74,11 @@ SCAN_INTERVAL = timedelta(seconds=10)
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
) -> None:
"""Set up the LG webOS Smart TV platform."""
unique_id = config_entry.unique_id
assert unique_id
name = config_entry.title
sources = config_entry.options.get(CONF_SOURCES)
client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id]
async_add_entities([LgWebOSMediaPlayerEntity(client, name, sources, unique_id)])
client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id]
async_add_entities([LgWebOSMediaPlayerEntity(entry, client)])
_T = TypeVar("_T", bound="LgWebOSMediaPlayerEntity")
@ -123,19 +117,14 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
_attr_device_class = MediaPlayerDeviceClass.TV
def __init__(
self,
client: WebOsClient,
name: str,
sources: list[str] | None,
unique_id: str,
) -> None:
def __init__(self, entry: ConfigEntry, client: WebOsClient) -> None:
"""Initialize the webos device."""
self._entry = entry
self._client = client
self._attr_assumed_state = True
self._attr_name = name
self._attr_unique_id = unique_id
self._sources = sources
self._attr_name = entry.title
self._attr_unique_id = entry.unique_id
self._sources = entry.options.get(CONF_SOURCES)
# Assume that the TV is not paused
self._paused = False
@ -326,7 +315,12 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
return
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError):
await self._client.connect()
try:
await self._client.connect()
except WebOsTvPairError:
self._entry.async_start_reauth(self.hass)
else:
update_client_key(self.hass, self._entry, self._client)
@property
def supported_features(self) -> MediaPlayerEntityFeature:

View File

@ -13,6 +13,10 @@
"pairing": {
"title": "webOS TV Pairing",
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
},
"reauth_confirm": {
"title": "webOS TV Pairing",
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
}
},
"error": {
@ -21,7 +25,9 @@
"abort": {
"error_pairing": "Connected to LG webOS TV but not paired",
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]"
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
}
},
"options": {

View File

@ -3,7 +3,9 @@
"abort": {
"already_configured": "Device is already configured",
"already_in_progress": "Configuration flow is already in progress",
"error_pairing": "Connected to LG webOS TV but not paired"
"error_pairing": "Connected to LG webOS TV but not paired",
"reauth_successful": "Re-authentication was successful",
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
},
"error": {
"cannot_connect": "Failed to connect, please turn on your TV or check ip address"
@ -14,6 +16,10 @@
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
"title": "webOS TV Pairing"
},
"reauth_confirm": {
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
"title": "webOS TV Pairing"
},
"user": {
"data": {
"host": "Host",

View File

@ -9,11 +9,11 @@ from homeassistant import config_entries
from homeassistant.components import ssdp
from homeassistant.components.webostv.const import CONF_SOURCES, DOMAIN, LIVE_TV_APP_ID
from homeassistant.config_entries import SOURCE_SSDP
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_SOURCE
from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME, CONF_SOURCE
from homeassistant.data_entry_flow import FlowResultType
from . import setup_webostv
from .const import FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME
from .const import CLIENT_KEY, FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME
MOCK_USER_CONFIG = {
CONF_HOST: HOST,
@ -289,3 +289,64 @@ async def test_form_abort_uuid_configured(hass, client):
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "already_configured"
assert entry.data[CONF_HOST] == "new_host"
async def test_reauth_successful(hass, client, monkeypatch):
"""Test that the reauthorization is successful."""
entry = await setup_webostv(hass)
assert client
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
data=entry.data,
)
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
assert entry.data[CONF_CLIENT_SECRET] == CLIENT_KEY
monkeypatch.setattr(client, "client_key", "new_key")
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert entry.data[CONF_CLIENT_SECRET] == "new_key"
@pytest.mark.parametrize(
"side_effect,reason",
[
(WebOsTvPairError, "error_pairing"),
(ConnectionRefusedError, "reauth_unsuccessful"),
],
)
async def test_reauth_errors(hass, client, monkeypatch, side_effect, reason):
"""Test reauthorization errors."""
entry = await setup_webostv(hass)
assert client
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
data=entry.data,
)
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
monkeypatch.setattr(client, "connect", Mock(side_effect=side_effect))
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == reason

View File

@ -0,0 +1,39 @@
"""The tests for the LG webOS TV platform."""
from unittest.mock import Mock
from aiowebostv import WebOsTvPairError
from homeassistant.components.webostv.const import DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
from homeassistant.const import CONF_CLIENT_SECRET
from . import setup_webostv
async def test_reauth_setup_entry(hass, client, monkeypatch):
"""Test reauth flow triggered by setup entry."""
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
entry = await setup_webostv(hass)
assert entry.state == ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
flow = flows[0]
assert flow.get("step_id") == "reauth_confirm"
assert flow.get("handler") == DOMAIN
assert "context" in flow
assert flow["context"].get("source") == SOURCE_REAUTH
assert flow["context"].get("entry_id") == entry.entry_id
async def test_key_update_setup_entry(hass, client, monkeypatch):
"""Test key update from setup entry."""
monkeypatch.setattr(client, "client_key", "new_key")
entry = await setup_webostv(hass)
assert entry.state == ConfigEntryState.LOADED
assert entry.data[CONF_CLIENT_SECRET] == "new_key"

View File

@ -4,6 +4,7 @@ from datetime import timedelta
from http import HTTPStatus
from unittest.mock import Mock
from aiowebostv import WebOsTvPairError
import pytest
from homeassistant.components import automation
@ -37,6 +38,7 @@ from homeassistant.components.webostv.media_player import (
SUPPORT_WEBOSTV,
SUPPORT_WEBOSTV_VOLUME,
)
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
from homeassistant.const import (
ATTR_COMMAND,
ATTR_DEVICE_CLASS,
@ -763,3 +765,28 @@ async def test_get_image_https(
content = await resp.read()
assert content == b"https_image"
async def test_reauth_reconnect(hass, client, monkeypatch):
"""Test reauth flow triggered by reconnect."""
entry = await setup_webostv(hass)
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
assert entry.state == ConfigEntryState.LOADED
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=20))
await hass.async_block_till_done()
assert entry.state == ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
flow = flows[0]
assert flow.get("step_id") == "reauth_confirm"
assert flow.get("handler") == DOMAIN
assert "context" in flow
assert flow["context"].get("source") == SOURCE_REAUTH
assert flow["context"].get("entry_id") == entry.entry_id