Allow a fixed number of ffmpeg proxy conversions per device (#129246)

Allow a fixed number of conversions per device
pull/129358/head
Michael Hansen 2024-10-28 15:26:43 -05:00 committed by GitHub
parent 73f2d972e4
commit dd9ce34d18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 97 additions and 10 deletions

View File

@ -1,10 +1,12 @@
"""HTTP view that converts audio from a URL to a preferred format."""
import asyncio
from collections import defaultdict
from dataclasses import dataclass, field
from http import HTTPStatus
import logging
import secrets
from typing import Final
from aiohttp import web
from aiohttp.abc import AbstractStreamWriter, BaseRequest
@ -17,6 +19,8 @@ from .const import DATA_FFMPEG_PROXY
_LOGGER = logging.getLogger(__name__)
_MAX_CONVERSIONS_PER_DEVICE: Final[int] = 2
def async_create_proxy_url(
hass: HomeAssistant,
@ -59,13 +63,18 @@ class FFmpegConversionInfo:
proc: asyncio.subprocess.Process | None = None
"""Subprocess doing ffmpeg conversion."""
is_finished: bool = False
"""True if conversion has finished."""
@dataclass
class FFmpegProxyData:
"""Data for ffmpeg proxy conversion."""
# device_id -> info
conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict)
# device_id -> [info]
conversions: dict[str, list[FFmpegConversionInfo]] = field(
default_factory=lambda: defaultdict(list)
)
def async_create_proxy_url(
self,
@ -77,8 +86,15 @@ class FFmpegProxyData:
width: int | None,
) -> str:
"""Create a one-time use proxy URL that automatically converts the media."""
if (convert_info := self.conversions.pop(device_id, None)) is not None:
# Stop existing conversion before overwriting info
# Remove completed conversions
device_conversions = [
info for info in self.conversions[device_id] if not info.is_finished
]
while len(device_conversions) >= _MAX_CONVERSIONS_PER_DEVICE:
# Stop oldest conversion before adding a new one
convert_info = device_conversions[0]
if (convert_info.proc is not None) and (
convert_info.proc.returncode is None
):
@ -87,12 +103,18 @@ class FFmpegProxyData:
)
convert_info.proc.kill()
device_conversions = device_conversions[1:]
convert_id = secrets.token_urlsafe(16)
self.conversions[device_id] = FFmpegConversionInfo(
convert_id, media_url, media_format, rate, channels, width
device_conversions.append(
FFmpegConversionInfo(
convert_id, media_url, media_format, rate, channels, width
)
)
_LOGGER.debug("Media URL allowed by proxy: %s", media_url)
self.conversions[device_id] = device_conversions
return f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.{media_format}"
@ -167,6 +189,7 @@ class FFmpegConvertResponse(web.StreamResponse):
*command_args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
close_fds=False, # use posix_spawn in CPython < 3.13
)
# Only one conversion process per device is allowed
@ -198,6 +221,9 @@ class FFmpegConvertResponse(web.StreamResponse):
raise
finally:
# Allow conversion info to be removed
self.convert_info.is_finished = True
# Terminate hangs, so kill is used
if proc.returncode is None:
proc.kill()
@ -224,7 +250,8 @@ class FFmpegProxyView(HomeAssistantView):
self, request: web.Request, device_id: str, filename: str
) -> web.StreamResponse:
"""Start a get request."""
if (convert_info := self.proxy_data.conversions.get(device_id)) is None:
device_conversions = self.proxy_data.conversions[device_id]
if not device_conversions:
return web.Response(
body="No proxy URL for device", status=HTTPStatus.NOT_FOUND
)
@ -232,9 +259,16 @@ class FFmpegProxyView(HomeAssistantView):
# {id}.mp3 -> id, mp3
convert_id, media_format = filename.rsplit(".")
if (convert_info.convert_id != convert_id) or (
convert_info.media_format != media_format
):
# Look up conversion info
convert_info: FFmpegConversionInfo | None = None
for maybe_convert_info in device_conversions:
if (maybe_convert_info.convert_id == convert_id) and (
maybe_convert_info.media_format == media_format
):
convert_info = maybe_convert_info
break
if convert_info is None:
return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST)
# Stop previous process if the URL is being reused.

View File

@ -2,6 +2,7 @@
from http import HTTPStatus
import io
import os
import tempfile
from unittest.mock import patch
from urllib.request import pathname2url
@ -232,3 +233,55 @@ async def test_request_same_url_multiple_times(
num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples
assert num_frames == 22050 * 10 # 10s
async def test_max_conversions_per_device(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
) -> None:
"""Test that each device has a maximum number of conversions (currently 2)."""
max_conversions = 2
device_ids = ["1234", "5678"]
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
client = await hass_client()
with tempfile.TemporaryDirectory() as temp_dir:
wav_paths = [
os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1)
]
for wav_path in wav_paths:
with wave.open(wav_path, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s
wav_urls = [pathname2url(p) for p in wav_paths]
# Each device will have max + 1 conversions
device_urls = {
device_id: [
async_create_proxy_url(
hass,
device_id,
wav_url,
media_format="wav",
rate=22050,
channels=2,
width=2,
)
for wav_url in wav_urls
]
for device_id in device_ids
}
for urls in device_urls.values():
# First URL should fail because it was overwritten by the others
req = await client.get(urls[0])
assert req.status == HTTPStatus.BAD_REQUEST
# All other URLs should succeed
for url in urls[1:]:
req = await client.get(url)
assert req.status == HTTPStatus.OK