Media source to verify domain to avoid KeyError (#67137)

pull/67149/head
Paulus Schoutsen 2022-02-23 16:22:39 -08:00 committed by GitHub
parent fff74c66ae
commit a5383e40eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 7 deletions

View File

@ -85,11 +85,16 @@ def _get_media_item(
) -> MediaSourceItem:
"""Return media item."""
if media_content_id:
return MediaSourceItem.from_uri(hass, media_content_id)
item = MediaSourceItem.from_uri(hass, media_content_id)
else:
# We default to our own domain if its only one registered
domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN
return MediaSourceItem(hass, domain, "")
# We default to our own domain if its only one registered
domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN
return MediaSourceItem(hass, domain, "")
if item.domain is not None and item.domain not in hass.data[DOMAIN]:
raise ValueError("Unknown media source")
return item
@bind_hass
@ -106,7 +111,7 @@ async def async_browse_media(
try:
item = await _get_media_item(hass, media_content_id).async_browse()
except ValueError as err:
raise BrowseError("Not a media source item") from err
raise BrowseError(str(err)) from err
if content_filter is None or item.children is None:
return item
@ -128,7 +133,7 @@ async def async_resolve_media(hass: HomeAssistant, media_content_id: str) -> Pla
try:
item = _get_media_item(hass, media_content_id)
except ValueError as err:
raise Unresolvable("Not a media source item") from err
raise Unresolvable(str(err)) from err
return await item.async_resolve()

View File

@ -98,6 +98,10 @@ async def test_async_unresolve_media(hass):
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "invalid")
# Test invalid media source
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "media-source://media_source2")
async def test_websocket_browse_media(hass, hass_ws_client):
"""Test browse media websocket."""

View File

@ -54,7 +54,7 @@ async def test_async_browse_media(hass):
# Test invalid base
with pytest.raises(media_source.BrowseError) as excinfo:
await media_source.async_browse_media(hass, f"{const.URI_SCHEME}{DOMAIN}/")
assert str(excinfo.value) == "Not a media source item"
assert str(excinfo.value) == "Invalid media source URI"
# Test successful listing
media = await media_source.async_browse_media(