From 0acd77e60ad1efae4ed472fe25ff3f45e5360076 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 11 Sep 2025 12:46:36 -0500 Subject: [PATCH] Only use media path for TTS stream override (#152084) --- homeassistant/components/tts/__init__.py | 56 +++++++------- tests/components/tts/test_init.py | 99 ++++++------------------ 2 files changed, 54 insertions(+), 101 deletions(-) diff --git a/homeassistant/components/tts/__init__.py b/homeassistant/components/tts/__init__.py index f05b98a34675..f1ffc7e0aada 100644 --- a/homeassistant/components/tts/__init__.py +++ b/homeassistant/components/tts/__init__.py @@ -18,6 +18,7 @@ import secrets from time import monotonic from typing import Any, Final, Generic, Protocol, TypeVar +import aiofiles from aiohttp import web import mutagen from mutagen.id3 import ID3, TextFrame as ID3Text @@ -27,7 +28,6 @@ import voluptuous as vol from homeassistant.components import ffmpeg, websocket_api from homeassistant.components.http import HomeAssistantView from homeassistant.components.media_source import ( - async_resolve_media, generate_media_source_id as ms_generate_media_source_id, ) from homeassistant.config_entries import ConfigEntry @@ -43,7 +43,6 @@ from homeassistant.core import ( ) from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.event import async_call_later from homeassistant.helpers.network import get_url @@ -503,7 +502,7 @@ class ResultStream: _manager: SpeechManager # Override - _override_media_id: str | None = None + _override_media_path: Path | None = None @cached_property def url(self) -> str: @@ -556,7 +555,7 @@ class ResultStream: async def async_stream_result(self) -> AsyncGenerator[bytes]: """Get the stream of this result.""" - if self._override_media_id is not None: + if self._override_media_path is not None: # Overridden async for chunk in self._async_stream_override_result(): yield chunk @@ -570,46 +569,49 @@ class ResultStream: self.last_used = monotonic() - def async_override_result(self, media_id: str) -> None: - """Override the TTS stream with a different media id.""" - self._override_media_id = media_id + def async_override_result(self, media_path: str | Path) -> None: + """Override the TTS stream with a different media path.""" + self._override_media_path = Path(media_path) async def _async_stream_override_result(self) -> AsyncGenerator[bytes]: """Get the stream of the overridden result.""" - assert self._override_media_id is not None - media = await async_resolve_media(self.hass, self._override_media_id) + assert self._override_media_path is not None - # Determine if we need to do audio conversion - preferred_extension: str | None = self.options.get(ATTR_PREFERRED_FORMAT) - sample_rate: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_RATE) - sample_channels: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS) - sample_bytes: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES) + preferred_format = self.options.get(ATTR_PREFERRED_FORMAT) + to_sample_rate = self.options.get(ATTR_PREFERRED_SAMPLE_RATE) + to_sample_channels = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS) + to_sample_bytes = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES) needs_conversion = ( - preferred_extension - or (sample_rate is not None) - or (sample_channels is not None) - or (sample_bytes is not None) + (preferred_format is not None) + or (to_sample_rate is not None) + or (to_sample_channels is not None) + or (to_sample_bytes is not None) ) if not needs_conversion: - # Stream directly from URL (no conversion) - session = async_get_clientsession(self.hass) - async with session.get(media.url) as response: - async for chunk in response.content: + # Read file directly (no conversion) + async with aiofiles.open(self._override_media_path, "rb") as media_file: + while True: + chunk = await media_file.read(FFMPEG_CHUNK_SIZE) + if not chunk: + break yield chunk return # Use ffmpeg to convert audio to preferred format + if not preferred_format: + preferred_format = self._override_media_path.suffix[1:] # strip . + converted_audio = _async_convert_audio( self.hass, from_extension=None, - audio_input=media.path or media.url, - to_extension=preferred_extension, - to_sample_rate=sample_rate, - to_sample_channels=sample_channels, - to_sample_bytes=sample_bytes, + audio_input=self._override_media_path, + to_extension=preferred_format, + to_sample_rate=self.options.get(ATTR_PREFERRED_SAMPLE_RATE), + to_sample_channels=self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS), + to_sample_bytes=self.options.get(ATTR_PREFERRED_SAMPLE_BYTES), ) async for chunk in converted_audio: yield chunk diff --git a/tests/components/tts/test_init.py b/tests/components/tts/test_init.py index 21cb6528480b..dc50f18d5e19 100644 --- a/tests/components/tts/test_init.py +++ b/tests/components/tts/test_init.py @@ -12,7 +12,7 @@ import wave from freezegun.api import FrozenDateTimeFactory import pytest -from homeassistant.components import ffmpeg, media_source, tts +from homeassistant.components import ffmpeg, tts from homeassistant.components.media_player import ( ATTR_MEDIA_ANNOUNCE, ATTR_MEDIA_CONTENT_ID, @@ -43,7 +43,6 @@ from .common import ( ) from tests.common import MockModule, async_mock_service, mock_integration, mock_platform -from tests.test_util.aiohttp import AiohttpClientMocker from tests.typing import ClientSessionGenerator, WebSocketGenerator ORIG_WRITE_TAGS = tts.SpeechManager.write_tags @@ -2070,33 +2069,40 @@ async def test_async_internal_get_tts_audio_called( async def test_stream_override( - hass: HomeAssistant, - mock_tts_entity: MockTTSEntity, - aioclient_mock: AiohttpClientMocker, + hass: HomeAssistant, mock_tts_entity: MockTTSEntity ) -> None: - """Test overriding streams with a media id.""" + """Test overriding streams with a media path.""" await mock_config_entry_setup(hass, mock_tts_entity) stream = tts.async_create_stream(hass, mock_tts_entity.entity_id) stream.async_set_message("beer") - stream.async_override_result("test-media-id") - url = "http://www.home-assistant.io/resolved.mp3" - test_data = b"override-data" - aioclient_mock.get(url, content=test_data) + with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file: + with wave.open(wav_file, "wb") as wav_writer: + wav_writer.setframerate(16000) + wav_writer.setsampwidth(2) + wav_writer.setnchannels(1) + wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono - with patch( - "homeassistant.components.tts.async_resolve_media", - return_value=media_source.PlayMedia(url=url, mime_type="audio/mp3"), - ): + wav_file.seek(0) + + stream.async_override_result(wav_file.name) result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) - assert result_data == test_data + + # Verify the result + with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader: + assert wav_reader.getframerate() == 16000 + assert wav_reader.getsampwidth() == 2 + assert wav_reader.getnchannels() == 1 + assert wav_reader.readframes(wav_reader.getnframes()) == bytes( + 16000 * 2 + ) # 1 second @ 16Khz/mono async def test_stream_override_with_conversion( hass: HomeAssistant, mock_tts_entity: MockTTSEntity ) -> None: - """Test overriding streams with a media id that requires conversion.""" + """Test overriding streams with a media path that requires conversion.""" await mock_config_entry_setup(hass, mock_tts_entity) stream = tts.async_create_stream( @@ -2110,7 +2116,6 @@ async def test_stream_override_with_conversion( }, ) stream.async_set_message("beer") - stream.async_override_result("test-media-id") # Use a temp file here since ffmpeg will read it directly with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file: @@ -2121,17 +2126,10 @@ async def test_stream_override_with_conversion( wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono wav_file.seek(0) + stream.async_override_result(wav_file.name) + result_data = b"".join([chunk async for chunk in stream.async_stream_result()]) - url = f"file://{wav_file.name}" - with patch( - "homeassistant.components.tts.async_resolve_media", - return_value=media_source.PlayMedia(url=url, mime_type="audio/wav"), - ): - result_data = b"".join( - [chunk async for chunk in stream.async_stream_result()] - ) - - # Verify the preferred format + # Verify the result has the preferred format with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader: assert wav_reader.getframerate() == 22050 assert wav_reader.getsampwidth() == 2 @@ -2139,50 +2137,3 @@ async def test_stream_override_with_conversion( assert wav_reader.readframes(wav_reader.getnframes()) == bytes( 22050 * 2 * 2 ) # 1 second @ 22.5Khz/stereo - - -async def test_stream_override_with_conversion_path_preferred( - hass: HomeAssistant, mock_tts_entity: MockTTSEntity -) -> None: - """Test overriding streams with a media id that requires conversion and has a path.""" - await mock_config_entry_setup(hass, mock_tts_entity) - - stream = tts.async_create_stream( - hass, - mock_tts_entity.entity_id, - options={tts.ATTR_PREFERRED_FORMAT: "wav"}, - ) - stream.async_set_message("beer") - stream.async_override_result("test-media-id") - - # Use a temp file here since ffmpeg will read it directly - with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file: - with wave.open(wav_file, "wb") as wav_writer: - wav_writer.setframerate(16000) - wav_writer.setsampwidth(2) - wav_writer.setnchannels(1) - wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono - - wav_file.seek(0) - - # Path is preferred over URL - with patch( - "homeassistant.components.tts.async_resolve_media", - return_value=media_source.PlayMedia( - path=Path(wav_file.name), - url="http://bad-url.com", - mime_type="audio/wav", - ), - ): - result_data = b"".join( - [chunk async for chunk in stream.async_stream_result()] - ) - - # Verify the preferred format - with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader: - assert wav_reader.getframerate() == 16000 - assert wav_reader.getsampwidth() == 2 - assert wav_reader.getnchannels() == 1 - assert wav_reader.readframes(wav_reader.getnframes()) == bytes( - 16000 * 2 - ) # 1 second @ 16Khz/mono