Changes to filename and path validation (#45529)
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/45571/head
parent
4739e8a207
commit
b1c2cde40b
|
@ -9,7 +9,7 @@ import voluptuous as vol
|
|||
|
||||
from homeassistant.const import HTTP_OK
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.util import sanitize_filename
|
||||
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -70,8 +70,8 @@ def setup(hass, config):
|
|||
|
||||
overwrite = service.data.get(ATTR_OVERWRITE)
|
||||
|
||||
if subdir:
|
||||
subdir = sanitize_filename(subdir)
|
||||
# Check the path
|
||||
raise_if_invalid_path(subdir)
|
||||
|
||||
final_path = None
|
||||
|
||||
|
@ -101,8 +101,8 @@ def setup(hass, config):
|
|||
if not filename:
|
||||
filename = "ha_download"
|
||||
|
||||
# Remove stuff to ruin paths
|
||||
filename = sanitize_filename(filename)
|
||||
# Check the filename
|
||||
raise_if_invalid_filename(filename)
|
||||
|
||||
# Do we want to download to subdir, create if needed
|
||||
if subdir:
|
||||
|
@ -148,6 +148,16 @@ def setup(hass, config):
|
|||
{"url": url, "filename": filename},
|
||||
)
|
||||
|
||||
# Remove file if we started downloading but failed
|
||||
if final_path and os.path.isfile(final_path):
|
||||
os.remove(final_path)
|
||||
except ValueError:
|
||||
_LOGGER.exception("Invalid value")
|
||||
hass.bus.fire(
|
||||
f"{DOMAIN}_{DOWNLOAD_FAILED_EVENT}",
|
||||
{"url": url, "filename": filename},
|
||||
)
|
||||
|
||||
# Remove file if we started downloading but failed
|
||||
if final_path and os.path.isfile(final_path):
|
||||
os.remove(final_path)
|
||||
|
|
|
@ -12,7 +12,6 @@ from homeassistant.helpers import collection, config_validation as cv
|
|||
from homeassistant.helpers.service import async_register_admin_service
|
||||
from homeassistant.helpers.typing import ConfigType, HomeAssistantType, ServiceCallType
|
||||
from homeassistant.loader import async_get_integration
|
||||
from homeassistant.util import sanitize_path
|
||||
|
||||
from . import dashboard, resources, websocket
|
||||
from .const import (
|
||||
|
@ -47,7 +46,7 @@ YAML_DASHBOARD_SCHEMA = vol.Schema(
|
|||
{
|
||||
**DASHBOARD_BASE_CREATE_FIELDS,
|
||||
vol.Required(CONF_MODE): MODE_YAML,
|
||||
vol.Required(CONF_FILENAME): vol.All(cv.string, sanitize_path),
|
||||
vol.Required(CONF_FILENAME): cv.path,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY
|
|||
from homeassistant.components.media_player.errors import BrowseError
|
||||
from homeassistant.components.media_source.error import Unresolvable
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.util import sanitize_path
|
||||
from homeassistant.util import raise_if_invalid_filename
|
||||
|
||||
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
|
||||
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
|
||||
|
@ -50,8 +50,10 @@ class LocalSource(MediaSource):
|
|||
if source_dir_id not in self.hass.config.media_dirs:
|
||||
raise Unresolvable("Unknown source directory.")
|
||||
|
||||
if location != sanitize_path(location):
|
||||
raise Unresolvable("Invalid path.")
|
||||
try:
|
||||
raise_if_invalid_filename(location)
|
||||
except ValueError as err:
|
||||
raise Unresolvable("Invalid path.") from err
|
||||
|
||||
return source_dir_id, location
|
||||
|
||||
|
@ -189,8 +191,10 @@ class LocalMediaView(HomeAssistantView):
|
|||
self, request: web.Request, source_dir_id: str, location: str
|
||||
) -> web.FileResponse:
|
||||
"""Start a GET request."""
|
||||
if location != sanitize_path(location):
|
||||
raise web.HTTPNotFound()
|
||||
try:
|
||||
raise_if_invalid_filename(location)
|
||||
except ValueError as err:
|
||||
raise web.HTTPBadRequest() from err
|
||||
|
||||
if source_dir_id not in self.hass.config.media_dirs:
|
||||
raise web.HTTPNotFound()
|
||||
|
|
|
@ -23,7 +23,7 @@ from homeassistant.const import SERVICE_RELOAD
|
|||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.service import async_set_service_schema
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.util import sanitize_filename
|
||||
from homeassistant.util import raise_if_invalid_filename
|
||||
import homeassistant.util.dt as dt_util
|
||||
from homeassistant.util.yaml.loader import load_yaml
|
||||
|
||||
|
@ -135,7 +135,8 @@ def discover_scripts(hass):
|
|||
def execute_script(hass, name, data=None):
|
||||
"""Execute a script."""
|
||||
filename = f"{name}.py"
|
||||
with open(hass.config.path(FOLDER, sanitize_filename(filename))) as fil:
|
||||
raise_if_invalid_filename(filename)
|
||||
with open(hass.config.path(FOLDER, filename)) as fil:
|
||||
source = fil.read()
|
||||
execute(hass, filename, source, data)
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ from homeassistant.helpers import (
|
|||
template as template_helper,
|
||||
)
|
||||
from homeassistant.helpers.logging import KeywordStyleAdapter
|
||||
from homeassistant.util import sanitize_path, slugify as util_slugify
|
||||
from homeassistant.util import raise_if_invalid_path, slugify as util_slugify
|
||||
import homeassistant.util.dt as dt_util
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
@ -118,8 +118,10 @@ def path(value: Any) -> str:
|
|||
if not isinstance(value, str):
|
||||
raise vol.Invalid("Expected a string")
|
||||
|
||||
if sanitize_path(value) != value:
|
||||
raise vol.Invalid("Invalid path")
|
||||
try:
|
||||
raise_if_invalid_path(value)
|
||||
except ValueError as err:
|
||||
raise vol.Invalid("Invalid path") from err
|
||||
|
||||
return value
|
||||
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Deprecation helpers for Home Assistant."""
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
@ -73,3 +74,25 @@ def get_deprecated(
|
|||
)
|
||||
return config.get(old_name)
|
||||
return config.get(new_name, default)
|
||||
|
||||
|
||||
def deprecated_function(replacement: str) -> Callable[..., Callable]:
|
||||
"""Mark function as deprecated and provide a replacement function to be used instead."""
|
||||
|
||||
def deprecated_decorator(func: Callable) -> Callable:
|
||||
"""Decorate function as deprecated."""
|
||||
|
||||
@functools.wraps(func)
|
||||
def deprecated_func(*args: tuple, **kwargs: Dict[str, Any]) -> Any:
|
||||
"""Wrap for the original function."""
|
||||
logger = logging.getLogger(func.__module__)
|
||||
logger.warning(
|
||||
"%s is a deprecated function. Use %s instead",
|
||||
func.__name__,
|
||||
replacement,
|
||||
)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return deprecated_func
|
||||
|
||||
return deprecated_decorator
|
||||
|
|
|
@ -22,6 +22,7 @@ from typing import (
|
|||
|
||||
import slugify as unicode_slug
|
||||
|
||||
from ..helpers.deprecation import deprecated_function
|
||||
from .dt import as_local, utcnow
|
||||
|
||||
T = TypeVar("T")
|
||||
|
@ -32,6 +33,27 @@ RE_SANITIZE_FILENAME = re.compile(r"(~|\.\.|/|\\)")
|
|||
RE_SANITIZE_PATH = re.compile(r"(~|\.(\.)+)")
|
||||
|
||||
|
||||
def raise_if_invalid_filename(filename: str) -> None:
|
||||
"""
|
||||
Check if a filename is valid.
|
||||
|
||||
Raises a ValueError if the filename is invalid.
|
||||
"""
|
||||
if RE_SANITIZE_FILENAME.sub("", filename) != filename:
|
||||
raise ValueError(f"{filename} is not a safe filename")
|
||||
|
||||
|
||||
def raise_if_invalid_path(path: str) -> None:
|
||||
"""
|
||||
Check if a path is valid.
|
||||
|
||||
Raises a ValueError if the path is invalid.
|
||||
"""
|
||||
if RE_SANITIZE_PATH.sub("", path) != path:
|
||||
raise ValueError(f"{path} is not a safe path")
|
||||
|
||||
|
||||
@deprecated_function(replacement="raise_if_invalid_filename")
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""Check if a filename is safe.
|
||||
|
||||
|
@ -47,6 +69,7 @@ def sanitize_filename(filename: str) -> str:
|
|||
return filename
|
||||
|
||||
|
||||
@deprecated_function(replacement="raise_if_invalid_path")
|
||||
def sanitize_path(path: str) -> str:
|
||||
"""Check if a path is safe.
|
||||
|
||||
|
|
|
@ -23,7 +23,7 @@ async def test_async_browse_media(hass):
|
|||
await media_source.async_browse_media(
|
||||
hass, f"{const.URI_SCHEME}{const.DOMAIN}/local/test/not/exist"
|
||||
)
|
||||
assert str(excinfo.value) == "Path does not exist."
|
||||
assert str(excinfo.value) == "Invalid path."
|
||||
|
||||
# Test browse file
|
||||
with pytest.raises(media_source.BrowseError) as excinfo:
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
"""Test deprecation helpers."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from homeassistant.helpers.deprecation import deprecated_substitute, get_deprecated
|
||||
from homeassistant.helpers.deprecation import (
|
||||
deprecated_function,
|
||||
deprecated_substitute,
|
||||
get_deprecated,
|
||||
)
|
||||
|
||||
|
||||
class MockBaseClass:
|
||||
|
@ -78,3 +82,17 @@ def test_config_get_deprecated_new(mock_get_logger):
|
|||
config = {"new_name": True}
|
||||
assert get_deprecated(config, "new_name", "old_name") is True
|
||||
assert not mock_logger.warning.called
|
||||
|
||||
|
||||
def test_deprecated_function(caplog):
|
||||
"""Test deprecated_function decorator."""
|
||||
|
||||
@deprecated_function("new_function")
|
||||
def mock_deprecated_function():
|
||||
pass
|
||||
|
||||
mock_deprecated_function()
|
||||
assert (
|
||||
"mock_deprecated_function is a deprecated function. Use new_function instead"
|
||||
in caplog.text
|
||||
)
|
||||
|
|
|
@ -24,6 +24,34 @@ def test_sanitize_path():
|
|||
assert util.sanitize_path("~/../test/path") == ""
|
||||
|
||||
|
||||
def test_raise_if_invalid_filename():
|
||||
"""Test raise_if_invalid_filename."""
|
||||
assert util.raise_if_invalid_filename("test") is None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
util.raise_if_invalid_filename("/test")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
util.raise_if_invalid_filename("..test")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
util.raise_if_invalid_filename("\\test")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
util.raise_if_invalid_filename("\\../test")
|
||||
|
||||
|
||||
def test_raise_if_invalid_path():
|
||||
"""Test raise_if_invalid_path."""
|
||||
assert util.raise_if_invalid_path("test/path") is None
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert util.raise_if_invalid_path("~test/path")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
assert util.raise_if_invalid_path("~/../test/path")
|
||||
|
||||
|
||||
def test_slugify():
|
||||
"""Test slugify."""
|
||||
assert util.slugify("T-!@#$!#@$!$est") == "t_est"
|
||||
|
|
Loading…
Reference in New Issue