Avoid path construction for static files cache hit (#102882)

pull/103142/head^2
J. Nick Koston 2023-10-31 14:31:58 -05:00 committed by GitHub
parent 4d475a9758
commit 8eb7766f30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 24 deletions

View File

@ -22,10 +22,15 @@ CACHE_HEADERS: Final[Mapping[str, str]] = {
PATH_CACHE = LRU(512)
def _get_file_path(
filename: str | Path, directory: Path, follow_symlinks: bool
) -> Path | None:
filepath = directory.joinpath(filename).resolve()
def _get_file_path(rel_url: str, directory: Path, follow_symlinks: bool) -> Path | None:
"""Return the path to file on disk or None."""
filename = Path(rel_url)
if filename.anchor:
# rel_url is an absolute name like
# /static/\\machine_name\c$ or /static/D:\path
# where the static dir is totally different
raise HTTPForbidden
filepath: Path = directory.joinpath(filename).resolve()
if not follow_symlinks:
filepath.relative_to(directory)
# on opening a dir, load its contents if allowed
@ -40,27 +45,24 @@ class CachingStaticResource(StaticResource):
"""Static Resource handler that will add cache headers."""
async def _handle(self, request: Request) -> StreamResponse:
"""Return requested file from disk as a FileResponse."""
rel_url = request.match_info["filename"]
hass: HomeAssistant = request.app[KEY_HASS]
filename = Path(rel_url)
if filename.anchor:
# rel_url is an absolute name like
# /static/\\machine_name\c$ or /static/D:\path
# where the static dir is totally different
raise HTTPForbidden()
try:
key = (filename, self._directory, self._follow_symlinks)
if (filepath := PATH_CACHE.get(key)) is None:
filepath = PATH_CACHE[key] = await hass.async_add_executor_job(
_get_file_path, filename, self._directory, self._follow_symlinks
)
except (ValueError, FileNotFoundError) as error:
# relatively safe
raise HTTPNotFound() from error
except Exception as error:
# perm error or other kind!
request.app.logger.exception(error)
raise HTTPNotFound() from error
key = (rel_url, self._directory, self._follow_symlinks)
if (filepath := PATH_CACHE.get(key)) is None:
hass: HomeAssistant = request.app[KEY_HASS]
try:
filepath = await hass.async_add_executor_job(_get_file_path, *key)
except (ValueError, FileNotFoundError) as error:
# relatively safe
raise HTTPNotFound() from error
except HTTPForbidden:
# forbidden
raise
except Exception as error:
# perm error or other kind!
request.app.logger.exception(error)
raise HTTPNotFound() from error
PATH_CACHE[key] = filepath
if filepath:
return FileResponse(
@ -68,4 +70,5 @@ class CachingStaticResource(StaticResource):
chunk_size=self._chunk_size,
headers=CACHE_HEADERS,
)
return await super()._handle(request)

View File

@ -0,0 +1,61 @@
"""The tests for http static files."""
from pathlib import Path
from aiohttp.test_utils import TestClient
from aiohttp.web_exceptions import HTTPForbidden
import pytest
from homeassistant.components.http.static import CachingStaticResource, _get_file_path
from homeassistant.core import EVENT_HOMEASSISTANT_START, HomeAssistant
from homeassistant.setup import async_setup_component
from tests.typing import ClientSessionGenerator
@pytest.fixture(autouse=True)
async def http(hass: HomeAssistant) -> None:
"""Ensure http is set up."""
assert await async_setup_component(hass, "http", {})
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
await hass.async_block_till_done()
@pytest.fixture
async def mock_http_client(hass: HomeAssistant, aiohttp_client: ClientSessionGenerator):
"""Start the Home Assistant HTTP component."""
return await aiohttp_client(hass.http.app, server_kwargs={"skip_url_asserts": True})
@pytest.mark.parametrize(
("url", "canonical_url"),
(
("//a", "//a"),
("///a", "///a"),
("/c:\\a\\b", "/c:%5Ca%5Cb"),
),
)
async def test_static_path_blocks_anchors(
hass: HomeAssistant,
mock_http_client: TestClient,
tmp_path: Path,
url: str,
canonical_url: str,
) -> None:
"""Test static paths block anchors."""
app = hass.http.app
resource = CachingStaticResource(url, str(tmp_path))
assert resource.canonical == canonical_url
app.router.register_resource(resource)
app["allow_configured_cors"](resource)
resp = await mock_http_client.get(canonical_url, allow_redirects=False)
assert resp.status == 403
# Tested directly since aiohttp will block it before
# it gets here but we want to make sure if aiohttp ever
# changes we still block it.
with pytest.raises(HTTPForbidden):
_get_file_path(canonical_url, tmp_path, False)