Allow a fixed number of ffmpeg proxy conversions per device (#129246)
Allow a fixed number of conversions per devicepull/129358/head
parent
73f2d972e4
commit
dd9ce34d18
|
@ -1,10 +1,12 @@
|
|||
"""HTTP view that converts audio from a URL to a preferred format."""
|
||||
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
import secrets
|
||||
from typing import Final
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.abc import AbstractStreamWriter, BaseRequest
|
||||
|
@ -17,6 +19,8 @@ from .const import DATA_FFMPEG_PROXY
|
|||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
_MAX_CONVERSIONS_PER_DEVICE: Final[int] = 2
|
||||
|
||||
|
||||
def async_create_proxy_url(
|
||||
hass: HomeAssistant,
|
||||
|
@ -59,13 +63,18 @@ class FFmpegConversionInfo:
|
|||
proc: asyncio.subprocess.Process | None = None
|
||||
"""Subprocess doing ffmpeg conversion."""
|
||||
|
||||
is_finished: bool = False
|
||||
"""True if conversion has finished."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FFmpegProxyData:
|
||||
"""Data for ffmpeg proxy conversion."""
|
||||
|
||||
# device_id -> info
|
||||
conversions: dict[str, FFmpegConversionInfo] = field(default_factory=dict)
|
||||
# device_id -> [info]
|
||||
conversions: dict[str, list[FFmpegConversionInfo]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
|
||||
def async_create_proxy_url(
|
||||
self,
|
||||
|
@ -77,8 +86,15 @@ class FFmpegProxyData:
|
|||
width: int | None,
|
||||
) -> str:
|
||||
"""Create a one-time use proxy URL that automatically converts the media."""
|
||||
if (convert_info := self.conversions.pop(device_id, None)) is not None:
|
||||
# Stop existing conversion before overwriting info
|
||||
|
||||
# Remove completed conversions
|
||||
device_conversions = [
|
||||
info for info in self.conversions[device_id] if not info.is_finished
|
||||
]
|
||||
|
||||
while len(device_conversions) >= _MAX_CONVERSIONS_PER_DEVICE:
|
||||
# Stop oldest conversion before adding a new one
|
||||
convert_info = device_conversions[0]
|
||||
if (convert_info.proc is not None) and (
|
||||
convert_info.proc.returncode is None
|
||||
):
|
||||
|
@ -87,12 +103,18 @@ class FFmpegProxyData:
|
|||
)
|
||||
convert_info.proc.kill()
|
||||
|
||||
device_conversions = device_conversions[1:]
|
||||
|
||||
convert_id = secrets.token_urlsafe(16)
|
||||
self.conversions[device_id] = FFmpegConversionInfo(
|
||||
convert_id, media_url, media_format, rate, channels, width
|
||||
device_conversions.append(
|
||||
FFmpegConversionInfo(
|
||||
convert_id, media_url, media_format, rate, channels, width
|
||||
)
|
||||
)
|
||||
_LOGGER.debug("Media URL allowed by proxy: %s", media_url)
|
||||
|
||||
self.conversions[device_id] = device_conversions
|
||||
|
||||
return f"/api/esphome/ffmpeg_proxy/{device_id}/{convert_id}.{media_format}"
|
||||
|
||||
|
||||
|
@ -167,6 +189,7 @@ class FFmpegConvertResponse(web.StreamResponse):
|
|||
*command_args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
close_fds=False, # use posix_spawn in CPython < 3.13
|
||||
)
|
||||
|
||||
# Only one conversion process per device is allowed
|
||||
|
@ -198,6 +221,9 @@ class FFmpegConvertResponse(web.StreamResponse):
|
|||
|
||||
raise
|
||||
finally:
|
||||
# Allow conversion info to be removed
|
||||
self.convert_info.is_finished = True
|
||||
|
||||
# Terminate hangs, so kill is used
|
||||
if proc.returncode is None:
|
||||
proc.kill()
|
||||
|
@ -224,7 +250,8 @@ class FFmpegProxyView(HomeAssistantView):
|
|||
self, request: web.Request, device_id: str, filename: str
|
||||
) -> web.StreamResponse:
|
||||
"""Start a get request."""
|
||||
if (convert_info := self.proxy_data.conversions.get(device_id)) is None:
|
||||
device_conversions = self.proxy_data.conversions[device_id]
|
||||
if not device_conversions:
|
||||
return web.Response(
|
||||
body="No proxy URL for device", status=HTTPStatus.NOT_FOUND
|
||||
)
|
||||
|
@ -232,9 +259,16 @@ class FFmpegProxyView(HomeAssistantView):
|
|||
# {id}.mp3 -> id, mp3
|
||||
convert_id, media_format = filename.rsplit(".")
|
||||
|
||||
if (convert_info.convert_id != convert_id) or (
|
||||
convert_info.media_format != media_format
|
||||
):
|
||||
# Look up conversion info
|
||||
convert_info: FFmpegConversionInfo | None = None
|
||||
for maybe_convert_info in device_conversions:
|
||||
if (maybe_convert_info.convert_id == convert_id) and (
|
||||
maybe_convert_info.media_format == media_format
|
||||
):
|
||||
convert_info = maybe_convert_info
|
||||
break
|
||||
|
||||
if convert_info is None:
|
||||
return web.Response(body="Invalid proxy URL", status=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
# Stop previous process if the URL is being reused.
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from http import HTTPStatus
|
||||
import io
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch
|
||||
from urllib.request import pathname2url
|
||||
|
@ -232,3 +233,55 @@ async def test_request_same_url_multiple_times(
|
|||
num_frames += len(chunk) // (2 * 2) # 2 channels, 16-bit samples
|
||||
|
||||
assert num_frames == 22050 * 10 # 10s
|
||||
|
||||
|
||||
async def test_max_conversions_per_device(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
) -> None:
|
||||
"""Test that each device has a maximum number of conversions (currently 2)."""
|
||||
max_conversions = 2
|
||||
device_ids = ["1234", "5678"]
|
||||
|
||||
await async_setup_component(hass, esphome.DOMAIN, {esphome.DOMAIN: {}})
|
||||
client = await hass_client()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
wav_paths = [
|
||||
os.path.join(temp_dir, f"{i}.wav") for i in range(max_conversions + 1)
|
||||
]
|
||||
for wav_path in wav_paths:
|
||||
with wave.open(wav_path, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(16000 * 2 * 10)) # 10s
|
||||
|
||||
wav_urls = [pathname2url(p) for p in wav_paths]
|
||||
|
||||
# Each device will have max + 1 conversions
|
||||
device_urls = {
|
||||
device_id: [
|
||||
async_create_proxy_url(
|
||||
hass,
|
||||
device_id,
|
||||
wav_url,
|
||||
media_format="wav",
|
||||
rate=22050,
|
||||
channels=2,
|
||||
width=2,
|
||||
)
|
||||
for wav_url in wav_urls
|
||||
]
|
||||
for device_id in device_ids
|
||||
}
|
||||
|
||||
for urls in device_urls.values():
|
||||
# First URL should fail because it was overwritten by the others
|
||||
req = await client.get(urls[0])
|
||||
assert req.status == HTTPStatus.BAD_REQUEST
|
||||
|
||||
# All other URLs should succeed
|
||||
for url in urls[1:]:
|
||||
req = await client.get(url)
|
||||
assert req.status == HTTPStatus.OK
|
||||
|
|
Loading…
Reference in New Issue