diff --git a/homeassistant/components/wyoming/assist_satellite.py b/homeassistant/components/wyoming/assist_satellite.py index d43af2d21b9..5440b2bebeb 100644 --- a/homeassistant/components/wyoming/assist_satellite.py +++ b/homeassistant/components/wyoming/assist_satellite.py @@ -24,18 +24,20 @@ from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.vad import VoiceStarted, VoiceStopped from wyoming.wake import Detect, Detection -from homeassistant.components import assist_pipeline, intent, tts +from homeassistant.components import assist_pipeline, ffmpeg, intent, tts from homeassistant.components.assist_pipeline import PipelineEvent from homeassistant.components.assist_satellite import ( + AssistSatelliteAnnouncement, AssistSatelliteConfiguration, AssistSatelliteEntity, AssistSatelliteEntityDescription, + AssistSatelliteEntityFeature, ) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback -from .const import DOMAIN +from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH from .data import WyomingService from .devices import SatelliteDevice from .entity import WyomingSatelliteEntity @@ -49,6 +51,8 @@ _RESTART_SECONDS: Final = 3 _PING_TIMEOUT: Final = 5 _PING_SEND_DELAY: Final = 2 _PIPELINE_FINISH_TIMEOUT: Final = 1 +_TTS_SAMPLE_RATE: Final = 22050 +_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples # Wyoming stage -> Assist stage _STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = { @@ -83,6 +87,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): entity_description = AssistSatelliteEntityDescription(key="assist_satellite") _attr_translation_key = "assist_satellite" _attr_name = None + _attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE def __init__( self, @@ -116,6 +121,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): self.device.set_pipeline_listener(self._pipeline_changed) self.device.set_audio_settings_listener(self._audio_settings_changed) + # For announcements + self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None + self._played_event_received: asyncio.Event | None = None + @property def pipeline_entity_id(self) -> str | None: """Return the entity ID of the pipeline to use for the next conversation.""" @@ -131,9 +140,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): """Options passed for text-to-speech.""" return { tts.ATTR_PREFERRED_FORMAT: "wav", - tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, - tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, - tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, + tts.ATTR_PREFERRED_SAMPLE_RATE: _TTS_SAMPLE_RATE, + tts.ATTR_PREFERRED_SAMPLE_CHANNELS: SAMPLE_CHANNELS, + tts.ATTR_PREFERRED_SAMPLE_BYTES: SAMPLE_WIDTH, } async def async_added_to_hass(self) -> None: @@ -244,6 +253,76 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): ) ) + async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None: + """Announce media on the satellite. + + Should block until the announcement is done playing. + """ + assert self._client is not None + + if self._ffmpeg_manager is None: + self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass) + + if self._played_event_received is None: + self._played_event_received = asyncio.Event() + + self._played_event_received.clear() + await self._client.write_event( + AudioStart( + rate=_TTS_SAMPLE_RATE, + width=SAMPLE_WIDTH, + channels=SAMPLE_CHANNELS, + timestamp=0, + ).event() + ) + + timestamp = 0 + try: + # Use ffmpeg to convert to raw PCM audio with the appropriate format + proc = await asyncio.create_subprocess_exec( + self._ffmpeg_manager.binary, + "-i", + announcement.media_id, + "-f", + "s16le", + "-ac", + str(SAMPLE_CHANNELS), + "-ar", + str(_TTS_SAMPLE_RATE), + "-nostats", + "pipe:", + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + close_fds=False, # use posix_spawn in CPython < 3.13 + ) + assert proc.stdout is not None + while True: + chunk_bytes = await proc.stdout.read(_ANNOUNCE_CHUNK_BYTES) + if not chunk_bytes: + break + + chunk = AudioChunk( + rate=_TTS_SAMPLE_RATE, + width=SAMPLE_WIDTH, + channels=SAMPLE_CHANNELS, + audio=chunk_bytes, + timestamp=timestamp, + ) + await self._client.write_event(chunk.event()) + + timestamp += chunk.milliseconds + finally: + await self._client.write_event(AudioStop().event()) + if timestamp > 0: + # Wait the length of the audio or until we receive a played event + audio_seconds = timestamp / 1000 + try: + async with asyncio.timeout(audio_seconds + 0.5): + await self._played_event_received.wait() + except TimeoutError: + # Older satellite clients will wait longer than necessary + _LOGGER.debug("Did not receive played event for announcement") + # ------------------------------------------------------------------------- def start_satellite(self) -> None: @@ -511,6 +590,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity): elif Played.is_type(client_event.type): # TTS response has finished playing on satellite self.tts_response_finished() + + if self._played_event_received is not None: + self._played_event_received.set() else: _LOGGER.debug("Unexpected event from satellite: %s", client_event) diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index b837d2a9e76..d75b70dffa8 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -7,7 +7,8 @@ "assist_satellite", "assist_pipeline", "intent", - "conversation" + "conversation", + "ffmpeg" ], "documentation": "https://www.home-assistant.io/integrations/wyoming", "integration_type": "service", diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index f293f976242..0e4bb3da78c 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -5,6 +5,7 @@ from __future__ import annotations import asyncio from collections.abc import Callable import io +import tempfile from typing import Any from unittest.mock import patch import wave @@ -17,17 +18,18 @@ from wyoming.info import Info from wyoming.ping import Ping, Pong from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite +from wyoming.snd import Played from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated from wyoming.tts import Synthesize from wyoming.vad import VoiceStarted, VoiceStopped from wyoming.wake import Detect, Detection -from homeassistant.components import assist_pipeline, wyoming +from homeassistant.components import assist_pipeline, assist_satellite, wyoming from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite from homeassistant.components.wyoming.devices import SatelliteDevice from homeassistant.const import STATE_ON from homeassistant.core import HomeAssistant, State -from homeassistant.helpers import intent as intent_helper +from homeassistant.helpers import entity_registry as er, intent as intent_helper from homeassistant.setup import async_setup_component from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient @@ -65,7 +67,7 @@ def get_test_wav() -> bytes: wav_file.setnchannels(1) # Single frame - wav_file.writeframes(b"123") + wav_file.writeframes(b"1234") return wav_io.getvalue() @@ -73,10 +75,15 @@ def get_test_wav() -> bytes: class SatelliteAsyncTcpClient(MockAsyncTcpClient): """Satellite AsyncTcpClient.""" - def __init__(self, responses: list[Event]) -> None: + def __init__( + self, responses: list[Event], block_until_inject: bool = False + ) -> None: """Initialize client.""" super().__init__(responses) + self.block_until_inject = block_until_inject + self._responses_ready = asyncio.Event() + self.connect_event = asyncio.Event() self.run_satellite_event = asyncio.Event() self.detect_event = asyncio.Event() @@ -188,6 +195,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): async def read_event(self) -> Event | None: """Receive.""" + if self.block_until_inject and (not self.responses): + await self._responses_ready.wait() + event = await super().read_event() # Keep sending audio chunks instead of None @@ -196,6 +206,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): def inject_event(self, event: Event) -> None: """Put an event in as the next response.""" self.responses = [event, *self.responses] + self._responses_ready.set() async def test_satellite_pipeline(hass: HomeAssistant) -> None: @@ -416,7 +427,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: assert mock_client.tts_audio_chunk.rate == 22050 assert mock_client.tts_audio_chunk.width == 2 assert mock_client.tts_audio_chunk.channels == 1 - assert mock_client.tts_audio_chunk.audio == b"123" + assert mock_client.tts_audio_chunk.audio == b"1234" # Pipeline finished pipeline_event_callback( @@ -1283,3 +1294,85 @@ async def test_timers(hass: HomeAssistant) -> None: timer_finished = mock_client.timer_finished assert timer_finished is not None assert timer_finished.id == timer_started.id + + +async def test_announce( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test announce on satellite.""" + assert await async_setup_component(hass, assist_pipeline.DOMAIN, {}) + + def async_process_play_media_url(hass: HomeAssistant, media_id: str) -> str: + # Don't create a URL + return media_id + + with ( + tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_wav_file, + patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), + patch( + "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(responses=[], block_until_inject=True), + ) as mock_client, + patch( + "homeassistant.components.assist_satellite.entity.async_process_play_media_url", + new=async_process_play_media_url, + ), + ): + # Use test WAV data for media + with wave.open(temp_wav_file.name, "wb") as wav_file: + wav_file.setframerate(22050) + wav_file.setsampwidth(2) + wav_file.setnchannels(1) + wav_file.writeframes(bytes(22050 * 2)) # 1 sec + + temp_wav_file.seek(0) + + entry = await setup_config_entry(hass) + device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device + assert device is not None + + satellite_entry = next( + ( + maybe_entry + for maybe_entry in er.async_entries_for_device( + entity_registry, device.device_id + ) + if maybe_entry.domain == assist_satellite.DOMAIN + ), + None, + ) + assert satellite_entry is not None + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + announce_task = hass.async_create_background_task( + hass.services.async_call( + assist_satellite.DOMAIN, + "announce", + { + "entity_id": satellite_entry.entity_id, + "media_id": temp_wav_file.name, + }, + blocking=True, + ), + "wyoming_satellite_announce", + ) + + # Wait for audio to come from ffmpeg + async with asyncio.timeout(1): + await mock_client.tts_audio_start_event.wait() + await mock_client.tts_audio_chunk_event.wait() + await mock_client.tts_audio_stop_event.wait() + + # Stop announcement from blocking + mock_client.inject_event(Played().event()) + await announce_task + + # Stop the satellite + await hass.config_entries.async_unload(entry.entry_id) + await hass.async_block_till_done()