From a5383e40ebb81e84f37e9e6b8134440f202e923a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Wed, 23 Feb 2022 16:22:39 -0800 Subject: [PATCH] Media source to verify domain to avoid KeyError (#67137) --- .../components/media_source/__init__.py | 17 +++++++++++------ tests/components/media_source/test_init.py | 4 ++++ tests/components/netatmo/test_media_source.py | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/homeassistant/components/media_source/__init__.py b/homeassistant/components/media_source/__init__.py index e2bd1b4903b..77b254dcf9d 100644 --- a/homeassistant/components/media_source/__init__.py +++ b/homeassistant/components/media_source/__init__.py @@ -85,11 +85,16 @@ def _get_media_item( ) -> MediaSourceItem: """Return media item.""" if media_content_id: - return MediaSourceItem.from_uri(hass, media_content_id) + item = MediaSourceItem.from_uri(hass, media_content_id) + else: + # We default to our own domain if its only one registered + domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN + return MediaSourceItem(hass, domain, "") - # We default to our own domain if its only one registered - domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN - return MediaSourceItem(hass, domain, "") + if item.domain is not None and item.domain not in hass.data[DOMAIN]: + raise ValueError("Unknown media source") + + return item @bind_hass @@ -106,7 +111,7 @@ async def async_browse_media( try: item = await _get_media_item(hass, media_content_id).async_browse() except ValueError as err: - raise BrowseError("Not a media source item") from err + raise BrowseError(str(err)) from err if content_filter is None or item.children is None: return item @@ -128,7 +133,7 @@ async def async_resolve_media(hass: HomeAssistant, media_content_id: str) -> Pla try: item = _get_media_item(hass, media_content_id) except ValueError as err: - raise Unresolvable("Not a media source item") from err + raise Unresolvable(str(err)) from err return await item.async_resolve() diff --git a/tests/components/media_source/test_init.py b/tests/components/media_source/test_init.py index e36ccdac931..319ef295be3 100644 --- a/tests/components/media_source/test_init.py +++ b/tests/components/media_source/test_init.py @@ -98,6 +98,10 @@ async def test_async_unresolve_media(hass): with pytest.raises(media_source.Unresolvable): await media_source.async_resolve_media(hass, "invalid") + # Test invalid media source + with pytest.raises(media_source.Unresolvable): + await media_source.async_resolve_media(hass, "media-source://media_source2") + async def test_websocket_browse_media(hass, hass_ws_client): """Test browse media websocket.""" diff --git a/tests/components/netatmo/test_media_source.py b/tests/components/netatmo/test_media_source.py index c4741672186..db1a79145b4 100644 --- a/tests/components/netatmo/test_media_source.py +++ b/tests/components/netatmo/test_media_source.py @@ -54,7 +54,7 @@ async def test_async_browse_media(hass): # Test invalid base with pytest.raises(media_source.BrowseError) as excinfo: await media_source.async_browse_media(hass, f"{const.URI_SCHEME}{DOMAIN}/") - assert str(excinfo.value) == "Not a media source item" + assert str(excinfo.value) == "Invalid media source URI" # Test successful listing media = await media_source.async_browse_media(