From b1c2cde40bf69d51aa5cfd6c2331c8cf890ad698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Joakim=20S=C3=B8rensen?= Date: Tue, 26 Jan 2021 15:53:21 +0100 Subject: [PATCH] Changes to filename and path validation (#45529) Co-authored-by: Paulus Schoutsen --- .../components/downloader/__init__.py | 20 +++++++++---- homeassistant/components/lovelace/__init__.py | 3 +- .../components/media_source/local_source.py | 14 ++++++---- .../components/python_script/__init__.py | 5 ++-- homeassistant/helpers/config_validation.py | 8 ++++-- homeassistant/helpers/deprecation.py | 23 +++++++++++++++ homeassistant/util/__init__.py | 23 +++++++++++++++ .../media_source/test_local_source.py | 2 +- tests/helpers/test_deprecation.py | 20 ++++++++++++- tests/util/test_init.py | 28 +++++++++++++++++++ 10 files changed, 127 insertions(+), 19 deletions(-) diff --git a/homeassistant/components/downloader/__init__.py b/homeassistant/components/downloader/__init__.py index 0c87f04e3ab..94617ce43aa 100644 --- a/homeassistant/components/downloader/__init__.py +++ b/homeassistant/components/downloader/__init__.py @@ -9,7 +9,7 @@ import voluptuous as vol from homeassistant.const import HTTP_OK import homeassistant.helpers.config_validation as cv -from homeassistant.util import sanitize_filename +from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path _LOGGER = logging.getLogger(__name__) @@ -70,8 +70,8 @@ def setup(hass, config): overwrite = service.data.get(ATTR_OVERWRITE) - if subdir: - subdir = sanitize_filename(subdir) + # Check the path + raise_if_invalid_path(subdir) final_path = None @@ -101,8 +101,8 @@ def setup(hass, config): if not filename: filename = "ha_download" - # Remove stuff to ruin paths - filename = sanitize_filename(filename) + # Check the filename + raise_if_invalid_filename(filename) # Do we want to download to subdir, create if needed if subdir: @@ -148,6 +148,16 @@ def setup(hass, config): {"url": url, "filename": filename}, ) + # Remove file if we started downloading but failed + if final_path and os.path.isfile(final_path): + os.remove(final_path) + except ValueError: + _LOGGER.exception("Invalid value") + hass.bus.fire( + f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}", + {"url": url, "filename": filename}, + ) + # Remove file if we started downloading but failed if final_path and os.path.isfile(final_path): os.remove(final_path) diff --git a/homeassistant/components/lovelace/__init__.py b/homeassistant/components/lovelace/__init__.py index 7d0fe6574b9..99b00a92289 100644 --- a/homeassistant/components/lovelace/__init__.py +++ b/homeassistant/components/lovelace/__init__.py @@ -12,7 +12,6 @@ from homeassistant.helpers import collection, config_validation as cv from homeassistant.helpers.service import async_register_admin_service from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType from homeassistant.loader import async_get_integration -from homeassistant.util import sanitize_path from . import dashboard, resources, websocket from .const import ( @@ -47,7 +46,7 @@ YAML_DASHBOARD_SCHEMA = vol.Schema( { **DASHBOARD_BASE_CREATE_FIELDS, vol.Required(CONF_MODE): MODE_YAML, - vol.Required(CONF_FILENAME): vol.All(cv.string, sanitize_path), + vol.Required(CONF_FILENAME): cv.path, } ) diff --git a/homeassistant/components/media_source/local_source.py b/homeassistant/components/media_source/local_source.py index 6c60da562e0..d7a2bdfd938 100644 --- a/homeassistant/components/media_source/local_source.py +++ b/homeassistant/components/media_source/local_source.py @@ -10,7 +10,7 @@ from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY from homeassistant.components.media_player.errors import BrowseError from homeassistant.components.media_source.error import Unresolvable from homeassistant.core import HomeAssistant, callback -from homeassistant.util import sanitize_path +from homeassistant.util import raise_if_invalid_filename from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia @@ -50,8 +50,10 @@ class LocalSource(MediaSource): if source_dir_id not in self.hass.config.media_dirs: raise Unresolvable("Unknown source directory.") - if location != sanitize_path(location): - raise Unresolvable("Invalid path.") + try: + raise_if_invalid_filename(location) + except ValueError as err: + raise Unresolvable("Invalid path.") from err return source_dir_id, location @@ -189,8 +191,10 @@ class LocalMediaView(HomeAssistantView): self, request: web.Request, source_dir_id: str, location: str ) -> web.FileResponse: """Start a GET request.""" - if location != sanitize_path(location): - raise web.HTTPNotFound() + try: + raise_if_invalid_filename(location) + except ValueError as err: + raise web.HTTPBadRequest() from err if source_dir_id not in self.hass.config.media_dirs: raise web.HTTPNotFound() diff --git a/homeassistant/components/python_script/__init__.py b/homeassistant/components/python_script/__init__.py index 6bd840074d9..bed159cae6b 100644 --- a/homeassistant/components/python_script/__init__.py +++ b/homeassistant/components/python_script/__init__.py @@ -23,7 +23,7 @@ from homeassistant.const import SERVICE_RELOAD from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.service import async_set_service_schema from homeassistant.loader import bind_hass -from homeassistant.util import sanitize_filename +from homeassistant.util import raise_if_invalid_filename import homeassistant.util.dt as dt_util from homeassistant.util.yaml.loader import load_yaml @@ -135,7 +135,8 @@ def discover_scripts(hass): def execute_script(hass, name, data=None): """Execute a script.""" filename = f"{name}.py" - with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil: + raise_if_invalid_filename(filename) + with open(hass.config.path(FOLDER, filename)) as fil: source = fil.read() execute(hass, filename, source, data) diff --git a/homeassistant/helpers/config_validation.py b/homeassistant/helpers/config_validation.py index 28f18cf9407..acf6139708a 100644 --- a/homeassistant/helpers/config_validation.py +++ b/homeassistant/helpers/config_validation.py @@ -87,7 +87,7 @@ from homeassistant.helpers import ( template as template_helper, ) from homeassistant.helpers.logging import KeywordStyleAdapter -from homeassistant.util import sanitize_path, slugify as util_slugify +from homeassistant.util import raise_if_invalid_path, slugify as util_slugify import homeassistant.util.dt as dt_util # pylint: disable=invalid-name @@ -118,8 +118,10 @@ def path(value: Any) -> str: if not isinstance(value, str): raise vol.Invalid("Expected a string") - if sanitize_path(value) != value: - raise vol.Invalid("Invalid path") + try: + raise_if_invalid_path(value) + except ValueError as err: + raise vol.Invalid("Invalid path") from err return value diff --git a/homeassistant/helpers/deprecation.py b/homeassistant/helpers/deprecation.py index a62a2e63804..0022f888829 100644 --- a/homeassistant/helpers/deprecation.py +++ b/homeassistant/helpers/deprecation.py @@ -1,4 +1,5 @@ """Deprecation helpers for Home Assistant.""" +import functools import inspect import logging from typing import Any, Callable, Dict, Optional @@ -73,3 +74,25 @@ def get_deprecated( ) return config.get(old_name) return config.get(new_name, default) + + +def deprecated_function(replacement: str) -> Callable[..., Callable]: + """Mark function as deprecated and provide a replacement function to be used instead.""" + + def deprecated_decorator(func: Callable) -> Callable: + """Decorate function as deprecated.""" + + @functools.wraps(func) + def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any: + """Wrap for the original function.""" + logger = logging.getLogger(func.__module__) + logger.warning( + "%s is a deprecated function. Use %s instead", + func.__name__, + replacement, + ) + return func(*args, **kwargs) + + return deprecated_func + + return deprecated_decorator diff --git a/homeassistant/util/__init__.py b/homeassistant/util/__init__.py index d3178cb5ddd..ad4ca18e4fe 100644 --- a/homeassistant/util/__init__.py +++ b/homeassistant/util/__init__.py @@ -22,6 +22,7 @@ from typing import ( import slugify as unicode_slug +from ..helpers.deprecation import deprecated_function from .dt import as_local, utcnow T = TypeVar("T") @@ -32,6 +33,27 @@ RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)") RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)") +def raise_if_invalid_filename(filename: str) -> None: + """ + Check if a filename is valid. + + Raises a ValueError if the filename is invalid. + """ + if RE_SANITIZE_FILENAME.sub("", filename) != filename: + raise ValueError(f"{filename} is not a safe filename") + + +def raise_if_invalid_path(path: str) -> None: + """ + Check if a path is valid. + + Raises a ValueError if the path is invalid. + """ + if RE_SANITIZE_PATH.sub("", path) != path: + raise ValueError(f"{path} is not a safe path") + + +@deprecated_function(replacement="raise_if_invalid_filename") def sanitize_filename(filename: str) -> str: """Check if a filename is safe. @@ -47,6 +69,7 @@ def sanitize_filename(filename: str) -> str: return filename +@deprecated_function(replacement="raise_if_invalid_path") def sanitize_path(path: str) -> str: """Check if a path is safe. diff --git a/tests/components/media_source/test_local_source.py b/tests/components/media_source/test_local_source.py index e3e2a3f1617..ad10df7cfd3 100644 --- a/tests/components/media_source/test_local_source.py +++ b/tests/components/media_source/test_local_source.py @@ -23,7 +23,7 @@ async def test_async_browse_media(hass): await media_source.async_browse_media( hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist" ) - assert str(excinfo.value) == "Path does not exist." + assert str(excinfo.value) == "Invalid path." # Test browse file with pytest.raises(media_source.BrowseError) as excinfo: diff --git a/tests/helpers/test_deprecation.py b/tests/helpers/test_deprecation.py index 38410c3bf0f..ebabe12c1ad 100644 --- a/tests/helpers/test_deprecation.py +++ b/tests/helpers/test_deprecation.py @@ -1,7 +1,11 @@ """Test deprecation helpers.""" from unittest.mock import MagicMock, patch -from homeassistant.helpers.deprecation import deprecated_substitute, get_deprecated +from homeassistant.helpers.deprecation import ( + deprecated_function, + deprecated_substitute, + get_deprecated, +) class MockBaseClass: @@ -78,3 +82,17 @@ def test_config_get_deprecated_new(mock_get_logger): config = {"new_name": True} assert get_deprecated(config, "new_name", "old_name") is True assert not mock_logger.warning.called + + +def test_deprecated_function(caplog): + """Test deprecated_function decorator.""" + + @deprecated_function("new_function") + def mock_deprecated_function(): + pass + + mock_deprecated_function() + assert ( + "mock_deprecated_function is a deprecated function. Use new_function instead" + in caplog.text + ) diff --git a/tests/util/test_init.py b/tests/util/test_init.py index 8ba034b79da..3855b62deb6 100644 --- a/tests/util/test_init.py +++ b/tests/util/test_init.py @@ -24,6 +24,34 @@ def test_sanitize_path(): assert util.sanitize_path("~/../test/path") == "" +def test_raise_if_invalid_filename(): + """Test raise_if_invalid_filename.""" + assert util.raise_if_invalid_filename("test") is None + + with pytest.raises(ValueError): + util.raise_if_invalid_filename("/test") + + with pytest.raises(ValueError): + util.raise_if_invalid_filename("..test") + + with pytest.raises(ValueError): + util.raise_if_invalid_filename("\\test") + + with pytest.raises(ValueError): + util.raise_if_invalid_filename("\\../test") + + +def test_raise_if_invalid_path(): + """Test raise_if_invalid_path.""" + assert util.raise_if_invalid_path("test/path") is None + + with pytest.raises(ValueError): + assert util.raise_if_invalid_path("~test/path") + + with pytest.raises(ValueError): + assert util.raise_if_invalid_path("~/../test/path") + + def test_slugify(): """Test slugify.""" assert util.slugify("T-!@#$!#@$!$est") == "t_est"