Only use media path for TTS stream override (#152084)
parent
c5d552dc4a
commit
0acd77e60a
|
@ -18,6 +18,7 @@ import secrets
|
|||
from time import monotonic
|
||||
from typing import Any, Final, Generic, Protocol, TypeVar
|
||||
|
||||
import aiofiles
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
from mutagen.id3 import ID3, TextFrame as ID3Text
|
||||
|
@ -27,7 +28,6 @@ import voluptuous as vol
|
|||
from homeassistant.components import ffmpeg, websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_source import (
|
||||
async_resolve_media,
|
||||
generate_media_source_id as ms_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
|
@ -43,7 +43,6 @@ from homeassistant.core import (
|
|||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
@ -503,7 +502,7 @@ class ResultStream:
|
|||
_manager: SpeechManager
|
||||
|
||||
# Override
|
||||
_override_media_id: str | None = None
|
||||
_override_media_path: Path | None = None
|
||||
|
||||
@cached_property
|
||||
def url(self) -> str:
|
||||
|
@ -556,7 +555,7 @@ class ResultStream:
|
|||
|
||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of this result."""
|
||||
if self._override_media_id is not None:
|
||||
if self._override_media_path is not None:
|
||||
# Overridden
|
||||
async for chunk in self._async_stream_override_result():
|
||||
yield chunk
|
||||
|
@ -570,46 +569,49 @@ class ResultStream:
|
|||
|
||||
self.last_used = monotonic()
|
||||
|
||||
def async_override_result(self, media_id: str) -> None:
|
||||
"""Override the TTS stream with a different media id."""
|
||||
self._override_media_id = media_id
|
||||
def async_override_result(self, media_path: str | Path) -> None:
|
||||
"""Override the TTS stream with a different media path."""
|
||||
self._override_media_path = Path(media_path)
|
||||
|
||||
async def _async_stream_override_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of the overridden result."""
|
||||
assert self._override_media_id is not None
|
||||
media = await async_resolve_media(self.hass, self._override_media_id)
|
||||
assert self._override_media_path is not None
|
||||
|
||||
# Determine if we need to do audio conversion
|
||||
preferred_extension: str | None = self.options.get(ATTR_PREFERRED_FORMAT)
|
||||
sample_rate: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
|
||||
sample_channels: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
sample_bytes: int | None = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)
|
||||
preferred_format = self.options.get(ATTR_PREFERRED_FORMAT)
|
||||
to_sample_rate = self.options.get(ATTR_PREFERRED_SAMPLE_RATE)
|
||||
to_sample_channels = self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
to_sample_bytes = self.options.get(ATTR_PREFERRED_SAMPLE_BYTES)
|
||||
|
||||
needs_conversion = (
|
||||
preferred_extension
|
||||
or (sample_rate is not None)
|
||||
or (sample_channels is not None)
|
||||
or (sample_bytes is not None)
|
||||
(preferred_format is not None)
|
||||
or (to_sample_rate is not None)
|
||||
or (to_sample_channels is not None)
|
||||
or (to_sample_bytes is not None)
|
||||
)
|
||||
|
||||
if not needs_conversion:
|
||||
# Stream directly from URL (no conversion)
|
||||
session = async_get_clientsession(self.hass)
|
||||
async with session.get(media.url) as response:
|
||||
async for chunk in response.content:
|
||||
# Read file directly (no conversion)
|
||||
async with aiofiles.open(self._override_media_path, "rb") as media_file:
|
||||
while True:
|
||||
chunk = await media_file.read(FFMPEG_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
return
|
||||
|
||||
# Use ffmpeg to convert audio to preferred format
|
||||
if not preferred_format:
|
||||
preferred_format = self._override_media_path.suffix[1:] # strip .
|
||||
|
||||
converted_audio = _async_convert_audio(
|
||||
self.hass,
|
||||
from_extension=None,
|
||||
audio_input=media.path or media.url,
|
||||
to_extension=preferred_extension,
|
||||
to_sample_rate=sample_rate,
|
||||
to_sample_channels=sample_channels,
|
||||
to_sample_bytes=sample_bytes,
|
||||
audio_input=self._override_media_path,
|
||||
to_extension=preferred_format,
|
||||
to_sample_rate=self.options.get(ATTR_PREFERRED_SAMPLE_RATE),
|
||||
to_sample_channels=self.options.get(ATTR_PREFERRED_SAMPLE_CHANNELS),
|
||||
to_sample_bytes=self.options.get(ATTR_PREFERRED_SAMPLE_BYTES),
|
||||
)
|
||||
async for chunk in converted_audio:
|
||||
yield chunk
|
||||
|
|
|
@ -12,7 +12,7 @@ import wave
|
|||
from freezegun.api import FrozenDateTimeFactory
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import ffmpeg, media_source, tts
|
||||
from homeassistant.components import ffmpeg, tts
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
|
@ -43,7 +43,6 @@ from .common import (
|
|||
)
|
||||
|
||||
from tests.common import MockModule, async_mock_service, mock_integration, mock_platform
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
ORIG_WRITE_TAGS = tts.SpeechManager.write_tags
|
||||
|
@ -2070,33 +2069,40 @@ async def test_async_internal_get_tts_audio_called(
|
|||
|
||||
|
||||
async def test_stream_override(
|
||||
hass: HomeAssistant,
|
||||
mock_tts_entity: MockTTSEntity,
|
||||
aioclient_mock: AiohttpClientMocker,
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id."""
|
||||
"""Test overriding streams with a media path."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(hass, mock_tts_entity.entity_id)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
url = "http://www.home-assistant.io/resolved.mp3"
|
||||
test_data = b"override-data"
|
||||
aioclient_mock.get(url, content=test_data)
|
||||
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
|
||||
with wave.open(wav_file, "wb") as wav_writer:
|
||||
wav_writer.setframerate(16000)
|
||||
wav_writer.setsampwidth(2)
|
||||
wav_writer.setnchannels(1)
|
||||
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(url=url, mime_type="audio/mp3"),
|
||||
):
|
||||
wav_file.seek(0)
|
||||
|
||||
stream.async_override_result(wav_file.name)
|
||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
assert result_data == test_data
|
||||
|
||||
# Verify the result
|
||||
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
|
||||
assert wav_reader.getframerate() == 16000
|
||||
assert wav_reader.getsampwidth() == 2
|
||||
assert wav_reader.getnchannels() == 1
|
||||
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
|
||||
16000 * 2
|
||||
) # 1 second @ 16Khz/mono
|
||||
|
||||
|
||||
async def test_stream_override_with_conversion(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id that requires conversion."""
|
||||
"""Test overriding streams with a media path that requires conversion."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(
|
||||
|
@ -2110,7 +2116,6 @@ async def test_stream_override_with_conversion(
|
|||
},
|
||||
)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
# Use a temp file here since ffmpeg will read it directly
|
||||
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
|
||||
|
@ -2121,17 +2126,10 @@ async def test_stream_override_with_conversion(
|
|||
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono
|
||||
|
||||
wav_file.seek(0)
|
||||
stream.async_override_result(wav_file.name)
|
||||
result_data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
|
||||
url = f"file://{wav_file.name}"
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(url=url, mime_type="audio/wav"),
|
||||
):
|
||||
result_data = b"".join(
|
||||
[chunk async for chunk in stream.async_stream_result()]
|
||||
)
|
||||
|
||||
# Verify the preferred format
|
||||
# Verify the result has the preferred format
|
||||
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
|
||||
assert wav_reader.getframerate() == 22050
|
||||
assert wav_reader.getsampwidth() == 2
|
||||
|
@ -2139,50 +2137,3 @@ async def test_stream_override_with_conversion(
|
|||
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
|
||||
22050 * 2 * 2
|
||||
) # 1 second @ 22.5Khz/stereo
|
||||
|
||||
|
||||
async def test_stream_override_with_conversion_path_preferred(
|
||||
hass: HomeAssistant, mock_tts_entity: MockTTSEntity
|
||||
) -> None:
|
||||
"""Test overriding streams with a media id that requires conversion and has a path."""
|
||||
await mock_config_entry_setup(hass, mock_tts_entity)
|
||||
|
||||
stream = tts.async_create_stream(
|
||||
hass,
|
||||
mock_tts_entity.entity_id,
|
||||
options={tts.ATTR_PREFERRED_FORMAT: "wav"},
|
||||
)
|
||||
stream.async_set_message("beer")
|
||||
stream.async_override_result("test-media-id")
|
||||
|
||||
# Use a temp file here since ffmpeg will read it directly
|
||||
with tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as wav_file:
|
||||
with wave.open(wav_file, "wb") as wav_writer:
|
||||
wav_writer.setframerate(16000)
|
||||
wav_writer.setsampwidth(2)
|
||||
wav_writer.setnchannels(1)
|
||||
wav_writer.writeframes(bytes(16000 * 2)) # 1 second @ 16Khz/mono
|
||||
|
||||
wav_file.seek(0)
|
||||
|
||||
# Path is preferred over URL
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_resolve_media",
|
||||
return_value=media_source.PlayMedia(
|
||||
path=Path(wav_file.name),
|
||||
url="http://bad-url.com",
|
||||
mime_type="audio/wav",
|
||||
),
|
||||
):
|
||||
result_data = b"".join(
|
||||
[chunk async for chunk in stream.async_stream_result()]
|
||||
)
|
||||
|
||||
# Verify the preferred format
|
||||
with io.BytesIO(result_data) as wav_io, wave.open(wav_io, "rb") as wav_reader:
|
||||
assert wav_reader.getframerate() == 16000
|
||||
assert wav_reader.getsampwidth() == 2
|
||||
assert wav_reader.getnchannels() == 1
|
||||
assert wav_reader.readframes(wav_reader.getnframes()) == bytes(
|
||||
16000 * 2
|
||||
) # 1 second @ 16Khz/mono
|
||||
|
|
Loading…
Reference in New Issue