Changes to filename and path validation (#45529)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/45571/head
Joakim Sørensen 2021-01-26 15:53:21 +01:00 committed by GitHub
parent 4739e8a207
commit b1c2cde40b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 19 deletions

View File

@ -9,7 +9,7 @@ import voluptuous as vol
from homeassistant.const import HTTP_OK
import homeassistant.helpers.config_validation as cv
from homeassistant.util import sanitize_filename
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
_LOGGER = logging.getLogger(__name__)
@ -70,8 +70,8 @@ def setup(hass, config):
overwrite = service.data.get(ATTR_OVERWRITE)
if subdir:
subdir = sanitize_filename(subdir)
# Check the path
raise_if_invalid_path(subdir)
final_path = None
@ -101,8 +101,8 @@ def setup(hass, config):
if not filename:
filename = "ha_download"
# Remove stuff to ruin paths
filename = sanitize_filename(filename)
# Check the filename
raise_if_invalid_filename(filename)
# Do we want to download to subdir, create if needed
if subdir:
@ -148,6 +148,16 @@ def setup(hass, config):
{"url": url, "filename": filename},
)
# Remove file if we started downloading but failed
if final_path and os.path.isfile(final_path):
os.remove(final_path)
except ValueError:
_LOGGER.exception("Invalid value")
hass.bus.fire(
f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}",
{"url": url, "filename": filename},
)
# Remove file if we started downloading but failed
if final_path and os.path.isfile(final_path):
os.remove(final_path)

View File

@ -12,7 +12,6 @@ from homeassistant.helpers import collection, config_validation as cv
from homeassistant.helpers.service import async_register_admin_service
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType
from homeassistant.loader import async_get_integration
from homeassistant.util import sanitize_path
from . import dashboard, resources, websocket
from .const import (
@ -47,7 +46,7 @@ YAML_DASHBOARD_SCHEMA = vol.Schema(
{
**DASHBOARD_BASE_CREATE_FIELDS,
vol.Required(CONF_MODE): MODE_YAML,
vol.Required(CONF_FILENAME): vol.All(cv.string, sanitize_path),
vol.Required(CONF_FILENAME): cv.path,
}
)

View File

@ -10,7 +10,7 @@ from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
from homeassistant.components.media_player.errors import BrowseError
from homeassistant.components.media_source.error import Unresolvable
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import sanitize_path
from homeassistant.util import raise_if_invalid_filename
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
@ -50,8 +50,10 @@ class LocalSource(MediaSource):
if source_dir_id not in self.hass.config.media_dirs:
raise Unresolvable("Unknown source directory.")
if location != sanitize_path(location):
raise Unresolvable("Invalid path.")
try:
raise_if_invalid_filename(location)
except ValueError as err:
raise Unresolvable("Invalid path.") from err
return source_dir_id, location
@ -189,8 +191,10 @@ class LocalMediaView(HomeAssistantView):
self, request: web.Request, source_dir_id: str, location: str
) -> web.FileResponse:
"""Start a GET request."""
if location != sanitize_path(location):
raise web.HTTPNotFound()
try:
raise_if_invalid_filename(location)
except ValueError as err:
raise web.HTTPBadRequest() from err
if source_dir_id not in self.hass.config.media_dirs:
raise web.HTTPNotFound()

View File

@ -23,7 +23,7 @@ from homeassistant.const import SERVICE_RELOAD
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.service import async_set_service_schema
from homeassistant.loader import bind_hass
from homeassistant.util import sanitize_filename
from homeassistant.util import raise_if_invalid_filename
import homeassistant.util.dt as dt_util
from homeassistant.util.yaml.loader import load_yaml
@ -135,7 +135,8 @@ def discover_scripts(hass):
def execute_script(hass, name, data=None):
"""Execute a script."""
filename = f"{name}.py"
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil:
raise_if_invalid_filename(filename)
with open(hass.config.path(FOLDER, filename)) as fil:
source = fil.read()
execute(hass, filename, source, data)

View File

@ -87,7 +87,7 @@ from homeassistant.helpers import (
template as template_helper,
)
from homeassistant.helpers.logging import KeywordStyleAdapter
from homeassistant.util import sanitize_path, slugify as util_slugify
from homeassistant.util import raise_if_invalid_path, slugify as util_slugify
import homeassistant.util.dt as dt_util
# pylint: disable=invalid-name
@ -118,8 +118,10 @@ def path(value: Any) -> str:
if not isinstance(value, str):
raise vol.Invalid("Expected a string")
if sanitize_path(value) != value:
raise vol.Invalid("Invalid path")
try:
raise_if_invalid_path(value)
except ValueError as err:
raise vol.Invalid("Invalid path") from err
return value

View File

@ -1,4 +1,5 @@
"""Deprecation helpers for Home Assistant."""
import functools
import inspect
import logging
from typing import Any, Callable, Dict, Optional
@ -73,3 +74,25 @@ def get_deprecated(
)
return config.get(old_name)
return config.get(new_name, default)
def deprecated_function(replacement: str) -> Callable[..., Callable]:
"""Mark function as deprecated and provide a replacement function to be used instead."""
def deprecated_decorator(func: Callable) -> Callable:
"""Decorate function as deprecated."""
@functools.wraps(func)
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any:
"""Wrap for the original function."""
logger = logging.getLogger(func.__module__)
logger.warning(
"%s is a deprecated function. Use %s instead",
func.__name__,
replacement,
)
return func(*args, **kwargs)
return deprecated_func
return deprecated_decorator

View File

@ -22,6 +22,7 @@ from typing import (
import slugify as unicode_slug
from ..helpers.deprecation import deprecated_function
from .dt import as_local, utcnow
T = TypeVar("T")
@ -32,6 +33,27 @@ RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
def raise_if_invalid_filename(filename: str) -> None:
"""
Check if a filename is valid.
Raises a ValueError if the filename is invalid.
"""
if RE_SANITIZE_FILENAME.sub("", filename) != filename:
raise ValueError(f"{filename} is not a safe filename")
def raise_if_invalid_path(path: str) -> None:
"""
Check if a path is valid.
Raises a ValueError if the path is invalid.
"""
if RE_SANITIZE_PATH.sub("", path) != path:
raise ValueError(f"{path} is not a safe path")
@deprecated_function(replacement="raise_if_invalid_filename")
def sanitize_filename(filename: str) -> str:
"""Check if a filename is safe.
@ -47,6 +69,7 @@ def sanitize_filename(filename: str) -> str:
return filename
@deprecated_function(replacement="raise_if_invalid_path")
def sanitize_path(path: str) -> str:
"""Check if a path is safe.

View File

@ -23,7 +23,7 @@ async def test_async_browse_media(hass):
await media_source.async_browse_media(
hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist"
)
assert str(excinfo.value) == "Path does not exist."
assert str(excinfo.value) == "Invalid path."
# Test browse file
with pytest.raises(media_source.BrowseError) as excinfo:

View File

@ -1,7 +1,11 @@
"""Test deprecation helpers."""
from unittest.mock import MagicMock, patch
from homeassistant.helpers.deprecation import deprecated_substitute, get_deprecated
from homeassistant.helpers.deprecation import (
deprecated_function,
deprecated_substitute,
get_deprecated,
)
class MockBaseClass:
@ -78,3 +82,17 @@ def test_config_get_deprecated_new(mock_get_logger):
config = {"new_name": True}
assert get_deprecated(config, "new_name", "old_name") is True
assert not mock_logger.warning.called
def test_deprecated_function(caplog):
"""Test deprecated_function decorator."""
@deprecated_function("new_function")
def mock_deprecated_function():
pass
mock_deprecated_function()
assert (
"mock_deprecated_function is a deprecated function. Use new_function instead"
in caplog.text
)

View File

@ -24,6 +24,34 @@ def test_sanitize_path():
assert util.sanitize_path("~/../test/path") == ""
def test_raise_if_invalid_filename():
"""Test raise_if_invalid_filename."""
assert util.raise_if_invalid_filename("test") is None
with pytest.raises(ValueError):
util.raise_if_invalid_filename("/test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("..test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("\\test")
with pytest.raises(ValueError):
util.raise_if_invalid_filename("\\../test")
def test_raise_if_invalid_path():
"""Test raise_if_invalid_path."""
assert util.raise_if_invalid_path("test/path") is None
with pytest.raises(ValueError):
assert util.raise_if_invalid_path("~test/path")
with pytest.raises(ValueError):
assert util.raise_if_invalid_path("~/../test/path")
def test_slugify():
"""Test slugify."""
assert util.slugify("T-!@#$!#@$!$est") == "t_est"