Only use media path for TTS stream override (#152084)

pull/152130/head
Michael Hansen 2025-09-11 12:46:36 -05:00 committed by GitHub
parent c5d552dc4a
commit 0acd77e60a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 54 additions and 101 deletions

View File

@ -18,6 +18,7 @@ import secrets
from time import monotonic
from typing import Any, Final, Generic, Protocol, TypeVar
import aiofiles
from aiohttp import web
import mutagen
from mutagen.id3 import ID3, TextFrame as ID3Text
@ -27,7 +28,6 @@ import voluptuous as vol
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_source import (
async_resolve_media,
generate_media_source_id as ms_generate_media_source_id,
)
from homeassistant.config_entries import ConfigEntry
@ -43,7 +43,6 @@ from homeassistant.core import (
)
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.event import async_call_later
from homeassistant.helpers.network import get_url
@ -503,7 +502,7 @@ class ResultStream:
_manager: SpeechManager
# Override
_override_media_id: str | None = None
_override_media_path: Path | None = None
@cached_property
def url(self) -> str:
@ -556,7 +555,7 @@ class ResultStream:
async def async_stream_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of this result."""
if self._override_media_id is not None:
if self._override_media_path is not None:
# Overridden
async for chunk in self._async_stream_override_result():
yield chunk
@ -570,46 +569,49 @@ class ResultStream:
self.last_used = monotonic()
# NOTE: diff rendering with indentation stripped. Two variants of the same
# method appear back to back; the commit title ("Only use media path for TTS
# stream override") indicates the media-path variant is the one kept.
# Media-id variant: stores the id string unchanged in _override_media_id.
def async_override_result(self, media_id: str) -> None:
"""Override the TTS stream with a different media id."""
self._override_media_id = media_id
# Media-path variant: accepts either a str or a Path and normalises it to a
# Path before storing it in _override_media_path.
def async_override_result(self, media_path: str | Path) -> None:
"""Override the TTS stream with a different media path."""
self._override_media_path = Path(media_path)
# NOTE: diff rendering with indentation stripped. Lines from the media-id
# variant and the media-path variant of this method are interleaved below; the
# added comments label which variant each group of lines belongs to.
async def _async_stream_override_result(self) -> AsyncGenerator[bytes]:
"""Get the stream of the overridden result."""
# Media-id variant: the stored id is resolved to a media object first.
assert self._override_media_id is not None
media = await async_resolve_media(self.hass, self._override_media_id)
# Media-path variant: the stored path is used directly, nothing is resolved.
assert self._override_media_path is not None
# Determine if we need to do audio conversion
# Media-id variant of the option lookup (annotated locals).
preferred_extension: str | None = self.options.get(ATTR_PREFERRED_FORMAT)
sample_rate: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
sample_channels: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
sample_bytes: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)
# Media-path variant of the same option lookup.
preferred_format = self.options.get(ATTR_PREFERRED_FORMAT)
to_sample_rate = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
to_sample_channels = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
to_sample_bytes = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)
needs_conversion = (
# Media-id variant: the format is tested for truthiness, so an empty string
# does not request conversion; the sample options are compared against None.
preferred_extension
or (sample_rate is not None)
or (sample_channels is not None)
or (sample_bytes is not None)
# Media-path variant: every option, including the format, is compared
# against None.
(preferred_format is not None)
or (to_sample_rate is not None)
or (to_sample_channels is not None)
or (to_sample_bytes is not None)
)
if not needs_conversion:
# Media-id variant: fetch the resolved URL over HTTP and relay the response
# content as it is iterated.
# Stream directly from URL (no conversion)
session = async_get_clientsession(self.hass)
async with session.get(media.url) as response:
async for chunk in response.content:
# Media-path variant: read the file in FFMPEG_CHUNK_SIZE pieces and stop at
# the first empty read.
# Read file directly (no conversion)
async with aiofiles.open(self._override_media_path, "rb") as media_file:
while True:
chunk = await media_file.read(FFMPEG_CHUNK_SIZE)
if not chunk:
break
yield chunk
return
# Use ffmpeg to convert audio to preferred format
# Media-path variant: when no format was requested (None or empty) but some
# sample option triggered conversion, fall back to the override file's own
# extension.
# NOTE(review): a path without a suffix yields "" here; confirm that
# _async_convert_audio tolerates an empty to_extension.
if not preferred_format:
preferred_format = self._override_media_path.suffix[1:] # strip .
converted_audio = _async_convert_audio(
self.hass,
from_extension=None,
# Media-id variant: a local path from the resolved media is preferred, with
# the URL as fallback.
audio_input=media.path or media.url,
to_extension=preferred_extension,
to_sample_rate=sample_rate,
to_sample_channels=sample_channels,
to_sample_bytes=sample_bytes,
# Media-path variant: ffmpeg is given the override file path.
# NOTE(review): the three sample options are looked up again here instead of
# reusing to_sample_rate / to_sample_channels / to_sample_bytes bound above;
# presumably the values are identical.
audio_input=self._override_media_path,
to_extension=preferred_format,
to_sample_rate=self.options.get(ATTR_PREFERRED_SAMPLE_RATE),
to_sample_channels=self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
to_sample_bytes=self.options.get(ATTR_PREFERRED_SAMPLE_BYTES),
)
async for chunk in converted_audio:
yield chunk

View File

@ -12,7 +12,7 @@ import wave
from freezegun.api import FrozenDateTimeFactory
import pytest
from homeassistant.components import ffmpeg, media_source, tts
from homeassistant.components import ffmpeg, tts
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
@ -43,7 +43,6 @@ from .common import (
)
from tests.common import MockModule, async_mock_service, mock_integration, mock_platform
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator, WebSocketGenerator
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
@ -2070,33 +2069,40 @@ async def test_async_internal_get_tts_audio_called(
# NOTE: diff rendering with indentation stripped. The media-id version of this
# test (aiohttp mock plus patched async_resolve_media) is interleaved with the
# media-path version (temporary WAV file passed straight to the stream).
async def test_stream_override(
# Parameter list of the media-id version (uses the aiohttp client mock).
hass: HomeAssistant,
mock_tts_entity: MockTTSEntity,
aioclient_mock: AiohttpClientMocker,
# Parameter list of the media-path version.
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
"""Test overriding streams with a media id."""
"""Test overriding streams with a media path."""
await mock_config_entry_setup(hass, mock_tts_entity)
# No preferred format/sample options are passed, so no conversion is requested.
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
stream.async_set_message("beer")
# Media-id version: override with a placeholder id and serve the payload from
# a mocked URL.
stream.async_override_result("test-media-id")
url = "http://www.home-assistant.io/resolved.mp3"
test_data = b"override-data"
aioclient_mock.get(url, content=test_data)
# Media-path version: write a WAV file of zero-valued samples to a temp file.
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
with wave.open(wav_file, "wb") as wav_writer:
wav_writer.setframerate(16000)
wav_writer.setsampwidth(2)
wav_writer.setnchannels(1)
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16 kHz/mono
# Media-id version: make id resolution return the mocked URL.
with patch(
"homeassistant.components.tts.async_resolve_media",
return_value=media_source.PlayMedia(url=url, mime_type="audio/mp3"),
):
# Media-path version: rewind the temp file and hand its name to the stream.
wav_file.seek(0)
stream.async_override_result(wav_file.name)
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
# Media-id version: the streamed bytes must equal the mocked payload.
assert result_data == test_data
# Media-path version: the streamed bytes must parse as the same WAV that was
# written (rate, sample width, channel count and frames all unchanged).
# Verify the result
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
assert wav_reader.getframerate() == 16000
assert wav_reader.getsampwidth() == 2
assert wav_reader.getnchannels() == 1
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
16000 * 2
) # 1 second @ 16 kHz/mono
# NOTE: diff rendering with indentation stripped. Parts of the body are elided
# by the hunk headers ("@ -... +... @@" lines), and lines from the media-id and
# media-path versions of the test are interleaved.
async def test_stream_override_with_conversion(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
"""Test overriding streams with a media id that requires conversion."""
"""Test overriding streams with a media path that requires conversion."""
await mock_config_entry_setup(hass, mock_tts_entity)
stream = tts.async_create_stream(
@ -2110,7 +2116,6 @@ async def test_stream_override_with_conversion(
},
)
stream.async_set_message("beer")
# Media-id version: override with a placeholder id before the file exists.
stream.async_override_result("test-media-id")
# Use a temp file here since ffmpeg will read it directly
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
@ -2121,17 +2126,10 @@ async def test_stream_override_with_conversion(
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16 kHz/mono
wav_file.seek(0)
# Media-path version: pass the temp file name to the stream and collect the
# output.
stream.async_override_result(wav_file.name)
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
# Media-id version: resolve the placeholder id to a file:// URL of the temp
# file, then collect the output.
url = f"file://{wav_file.name}"
with patch(
"homeassistant.components.tts.async_resolve_media",
return_value=media_source.PlayMedia(url=url, mime_type="audio/wav"),
):
result_data = b"".join(
[chunk async for chunk in stream.async_stream_result()]
)
# Verify the preferred format
# Verify the result has the preferred format
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
assert wav_reader.getframerate() == 22050
assert wav_reader.getsampwidth() == 2
@ -2139,50 +2137,3 @@ async def test_stream_override_with_conversion(
# Expected payload: 22050 frames x 2 bytes per sample x 2 channels of zeros.
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
22050 * 2 * 2
) # 1 second @ 22.05 kHz/stereo
# NOTE: diff rendering with indentation stripped. The preceding hunk header
# (-2139,50 +2137,3) shows these lines exist only on the old side of the diff.
async def test_stream_override_with_conversion_path_preferred(
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
) -> None:
"""Test overriding streams with a media id that requires conversion and has a path."""
await mock_config_entry_setup(hass, mock_tts_entity)
# Only the container format is requested; no sample rate/width/channel options.
stream = tts.async_create_stream(
hass,
mock_tts_entity.entity_id,
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
)
stream.async_set_message("beer")
stream.async_override_result("test-media-id")
# Use a temp file here since ffmpeg will read it directly
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
with wave.open(wav_file, "wb") as wav_writer:
wav_writer.setframerate(16000)
wav_writer.setsampwidth(2)
wav_writer.setnchannels(1)
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16 kHz/mono
wav_file.seek(0)
# The resolved media carries both a valid local path and a deliberately bad
# URL, so the assertions below can only pass if the path is the one read.
# Path is preferred over URL
with patch(
"homeassistant.components.tts.async_resolve_media",
return_value=media_source.PlayMedia(
path=Path(wav_file.name),
url="http://bad-url.com",
mime_type="audio/wav",
),
):
result_data = b"".join(
[chunk async for chunk in stream.async_stream_result()]
)
# The output is expected to keep the input's rate, sample width, channel count
# and frames.
# Verify the preferred format
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
assert wav_reader.getframerate() == 16000
assert wav_reader.getsampwidth() == 2
assert wav_reader.getnchannels() == 1
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
16000 * 2
) # 1 second @ 16 kHz/mono