diff --git a/homeassistant/components/webostv/__init__.py b/homeassistant/components/webostv/__init__.py index 7852ca568a0..6e960ceb143 100644 --- a/homeassistant/components/webostv/__init__.py +++ b/homeassistant/components/webostv/__init__.py @@ -18,6 +18,7 @@ from homeassistant.const import ( EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import Event, HomeAssistant, ServiceCall +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.typing import ConfigType @@ -77,8 +78,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: # Attempt a connection, but fail gracefully if tv is off for example. client = WebOsClient(host, key) - with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError): - await client.connect() + with suppress(*WEBOSTV_EXCEPTIONS): + try: + await client.connect() + except WebOsTvPairError as err: + raise ConfigEntryAuthFailed(err) from err + + # If pairing request accepted there will be no error + # Update the stored key without triggering reauth + update_client_key(hass, entry, client) async def async_service_handler(service: ServiceCall) -> None: method = SERVICE_TO_METHOD[service.service] @@ -141,6 +149,19 @@ async def async_control_connect(host: str, key: str | None) -> WebOsClient: return client +def update_client_key( + hass: HomeAssistant, entry: ConfigEntry, client: WebOsClient +) -> None: + """Check and update stored client key if key has changed.""" + host = entry.data[CONF_HOST] + key = entry.data[CONF_CLIENT_SECRET] + + if client.client_key != key: + _LOGGER.debug("Updating client key for host %s", host) + data = {CONF_HOST: host, CONF_CLIENT_SECRET: client.client_key} + hass.config_entries.async_update_entry(entry, data=data) + + async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) diff --git a/homeassistant/components/webostv/config_flow.py b/homeassistant/components/webostv/config_flow.py index ebf032498fa..1669e5a4c89 100644 --- a/homeassistant/components/webostv/config_flow.py +++ b/homeassistant/components/webostv/config_flow.py @@ -1,6 +1,7 @@ """Config flow to configure webostv component.""" from __future__ import annotations +from collections.abc import Mapping import logging from typing import Any from urllib.parse import urlparse @@ -8,14 +9,14 @@ from urllib.parse import urlparse from aiowebostv import WebOsTvPairError import voluptuous as vol -from homeassistant import config_entries, data_entry_flow from homeassistant.components import ssdp +from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME from homeassistant.core import callback -from homeassistant.data_entry_flow import FlowResult +from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.helpers import config_validation as cv -from . import async_control_connect +from . import async_control_connect, update_client_key from .const import CONF_SOURCES, DEFAULT_NAME, DOMAIN, WEBOSTV_EXCEPTIONS from .helpers import async_get_sources @@ -30,7 +31,7 @@ DATA_SCHEMA = vol.Schema( _LOGGER = logging.getLogger(__name__) -class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): +class FlowHandler(ConfigFlow, domain=DOMAIN): """WebosTV configuration flow.""" VERSION = 1 @@ -40,12 +41,11 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self._host: str = "" self._name: str = "" self._uuid: str | None = None + self._entry: ConfigEntry | None = None @staticmethod @callback - def async_get_options_flow( - config_entry: config_entries.ConfigEntry, - ) -> OptionsFlowHandler: + def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow: """Get the options flow for this handler.""" return OptionsFlowHandler(config_entry) @@ -78,7 +78,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): ) self.hass.config_entries.async_update_entry(entry, unique_id=self._uuid) - raise data_entry_flow.AbortFlow("already_configured") + raise AbortFlow("already_configured") async def async_step_pairing( self, user_input: dict[str, Any] | None = None @@ -129,11 +129,37 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): self._uuid = uuid return await self.async_step_pairing() + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Perform reauth upon an WebOsTvPairError.""" + self._host = entry_data[CONF_HOST] + self._entry = self.hass.config_entries.async_get_entry(self.context["entry_id"]) + return await self.async_step_reauth_confirm() -class OptionsFlowHandler(config_entries.OptionsFlow): + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Dialog that informs the user that reauth is required.""" + assert self._entry is not None + + if user_input is not None: + try: + client = await async_control_connect(self._host, None) + except WebOsTvPairError: + return self.async_abort(reason="error_pairing") + except WEBOSTV_EXCEPTIONS: + return self.async_abort(reason="reauth_unsuccessful") + + update_client_key(self.hass, self._entry, client) + await self.hass.config_entries.async_reload(self._entry.entry_id) + return self.async_abort(reason="reauth_successful") + + return self.async_show_form(step_id="reauth_confirm") + + +class OptionsFlowHandler(OptionsFlow): """Handle options.""" - def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + def __init__(self, config_entry: ConfigEntry) -> None: """Initialize options flow.""" self.config_entry = config_entry self.options = config_entry.options diff --git a/homeassistant/components/webostv/media_player.py b/homeassistant/components/webostv/media_player.py index 53c7fb66825..36af5ef893f 100644 --- a/homeassistant/components/webostv/media_player.py +++ b/homeassistant/components/webostv/media_player.py @@ -39,6 +39,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.trigger import PluggableAction +from . import update_client_key from .const import ( ATTR_PAYLOAD, ATTR_SOUND_OUTPUT, @@ -73,18 +74,11 @@ SCAN_INTERVAL = timedelta(seconds=10) async def async_setup_entry( - hass: HomeAssistant, - config_entry: ConfigEntry, - async_add_entities: AddEntitiesCallback, + hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback ) -> None: """Set up the LG webOS Smart TV platform.""" - unique_id = config_entry.unique_id - assert unique_id - name = config_entry.title - sources = config_entry.options.get(CONF_SOURCES) - client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id] - - async_add_entities([LgWebOSMediaPlayerEntity(client, name, sources, unique_id)]) + client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id] + async_add_entities([LgWebOSMediaPlayerEntity(entry, client)]) _T = TypeVar("_T", bound="LgWebOSMediaPlayerEntity") @@ -123,19 +117,14 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): _attr_device_class = MediaPlayerDeviceClass.TV - def __init__( - self, - client: WebOsClient, - name: str, - sources: list[str] | None, - unique_id: str, - ) -> None: + def __init__(self, entry: ConfigEntry, client: WebOsClient) -> None: """Initialize the webos device.""" + self._entry = entry self._client = client self._attr_assumed_state = True - self._attr_name = name - self._attr_unique_id = unique_id - self._sources = sources + self._attr_name = entry.title + self._attr_unique_id = entry.unique_id + self._sources = entry.options.get(CONF_SOURCES) # Assume that the TV is not paused self._paused = False @@ -326,7 +315,12 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity): return with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError): - await self._client.connect() + try: + await self._client.connect() + except WebOsTvPairError: + self._entry.async_start_reauth(self.hass) + else: + update_client_key(self.hass, self._entry, self._client) @property def supported_features(self) -> MediaPlayerEntityFeature: diff --git a/homeassistant/components/webostv/strings.json b/homeassistant/components/webostv/strings.json index 21e46e8e304..c623effe22b 100644 --- a/homeassistant/components/webostv/strings.json +++ b/homeassistant/components/webostv/strings.json @@ -13,6 +13,10 @@ "pairing": { "title": "webOS TV Pairing", "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)" + }, + "reauth_confirm": { + "title": "webOS TV Pairing", + "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)" } }, "error": { @@ -21,7 +25,9 @@ "abort": { "error_pairing": "Connected to LG webOS TV but not paired", "already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]", - "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", + "reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again." } }, "options": { diff --git a/homeassistant/components/webostv/translations/en.json b/homeassistant/components/webostv/translations/en.json index dc0c8433151..87f9dd6c84b 100644 --- a/homeassistant/components/webostv/translations/en.json +++ b/homeassistant/components/webostv/translations/en.json @@ -3,7 +3,9 @@ "abort": { "already_configured": "Device is already configured", "already_in_progress": "Configuration flow is already in progress", - "error_pairing": "Connected to LG webOS TV but not paired" + "error_pairing": "Connected to LG webOS TV but not paired", + "reauth_successful": "Re-authentication was successful", + "reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again." }, "error": { "cannot_connect": "Failed to connect, please turn on your TV or check ip address" @@ -14,6 +16,10 @@ "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)", "title": "webOS TV Pairing" }, + "reauth_confirm": { + "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)", + "title": "webOS TV Pairing" + }, "user": { "data": { "host": "Host", diff --git a/tests/components/webostv/test_config_flow.py b/tests/components/webostv/test_config_flow.py index cdb995de8ca..952307d9c26 100644 --- a/tests/components/webostv/test_config_flow.py +++ b/tests/components/webostv/test_config_flow.py @@ -9,11 +9,11 @@ from homeassistant import config_entries from homeassistant.components import ssdp from homeassistant.components.webostv.const import CONF_SOURCES, DOMAIN, LIVE_TV_APP_ID from homeassistant.config_entries import SOURCE_SSDP -from homeassistant.const import CONF_HOST, CONF_NAME, CONF_SOURCE +from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME, CONF_SOURCE from homeassistant.data_entry_flow import FlowResultType from . import setup_webostv -from .const import FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME +from .const import CLIENT_KEY, FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME MOCK_USER_CONFIG = { CONF_HOST: HOST, @@ -289,3 +289,64 @@ async def test_form_abort_uuid_configured(hass, client): assert result["type"] == FlowResultType.ABORT assert result["reason"] == "already_configured" assert entry.data[CONF_HOST] == "new_host" + + +async def test_reauth_successful(hass, client, monkeypatch): + """Test that the reauthorization is successful.""" + entry = await setup_webostv(hass) + assert client + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id}, + data=entry.data, + ) + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + assert entry.data[CONF_CLIENT_SECRET] == CLIENT_KEY + + monkeypatch.setattr(client, "client_key", "new_key") + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={} + ) + + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + assert entry.data[CONF_CLIENT_SECRET] == "new_key" + + +@pytest.mark.parametrize( + "side_effect,reason", + [ + (WebOsTvPairError, "error_pairing"), + (ConnectionRefusedError, "reauth_unsuccessful"), + ], +) +async def test_reauth_errors(hass, client, monkeypatch, side_effect, reason): + """Test reauthorization errors.""" + entry = await setup_webostv(hass) + assert client + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id}, + data=entry.data, + ) + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + monkeypatch.setattr(client, "connect", Mock(side_effect=side_effect)) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={} + ) + + assert result["type"] == FlowResultType.ABORT + assert result["reason"] == reason diff --git a/tests/components/webostv/test_init.py b/tests/components/webostv/test_init.py new file mode 100644 index 00000000000..e48bb9d80fd --- /dev/null +++ b/tests/components/webostv/test_init.py @@ -0,0 +1,39 @@ +"""The tests for the LG webOS TV platform.""" +from unittest.mock import Mock + +from aiowebostv import WebOsTvPairError + +from homeassistant.components.webostv.const import DOMAIN +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState +from homeassistant.const import CONF_CLIENT_SECRET + +from . import setup_webostv + + +async def test_reauth_setup_entry(hass, client, monkeypatch): + """Test reauth flow triggered by setup entry.""" + monkeypatch.setattr(client, "is_connected", Mock(return_value=False)) + monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError)) + entry = await setup_webostv(hass) + + assert entry.state == ConfigEntryState.SETUP_ERROR + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + + flow = flows[0] + assert flow.get("step_id") == "reauth_confirm" + assert flow.get("handler") == DOMAIN + + assert "context" in flow + assert flow["context"].get("source") == SOURCE_REAUTH + assert flow["context"].get("entry_id") == entry.entry_id + + +async def test_key_update_setup_entry(hass, client, monkeypatch): + """Test key update from setup entry.""" + monkeypatch.setattr(client, "client_key", "new_key") + entry = await setup_webostv(hass) + + assert entry.state == ConfigEntryState.LOADED + assert entry.data[CONF_CLIENT_SECRET] == "new_key" diff --git a/tests/components/webostv/test_media_player.py b/tests/components/webostv/test_media_player.py index e4e2e2ba45f..f12c07c66c9 100644 --- a/tests/components/webostv/test_media_player.py +++ b/tests/components/webostv/test_media_player.py @@ -4,6 +4,7 @@ from datetime import timedelta from http import HTTPStatus from unittest.mock import Mock +from aiowebostv import WebOsTvPairError import pytest from homeassistant.components import automation @@ -37,6 +38,7 @@ from homeassistant.components.webostv.media_player import ( SUPPORT_WEBOSTV, SUPPORT_WEBOSTV_VOLUME, ) +from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState from homeassistant.const import ( ATTR_COMMAND, ATTR_DEVICE_CLASS, @@ -763,3 +765,28 @@ async def test_get_image_https( content = await resp.read() assert content == b"https_image" + + +async def test_reauth_reconnect(hass, client, monkeypatch): + """Test reauth flow triggered by reconnect.""" + entry = await setup_webostv(hass) + monkeypatch.setattr(client, "is_connected", Mock(return_value=False)) + monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError)) + + assert entry.state == ConfigEntryState.LOADED + + async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=20)) + await hass.async_block_till_done() + + assert entry.state == ConfigEntryState.LOADED + + flows = hass.config_entries.flow.async_progress() + assert len(flows) == 1 + + flow = flows[0] + assert flow.get("step_id") == "reauth_confirm" + assert flow.get("handler") == DOMAIN + + assert "context" in flow + assert flow["context"].get("source") == SOURCE_REAUTH + assert flow["context"].get("entry_id") == entry.entry_id