Add Wyoming satellite announce (#138221)

* Add Wyoming satellite announce

* Initialize when necessary
pull/138240/head
Michael Hansen 2025-02-10 14:36:20 -06:00 committed by GitHub
parent f83c8de8d3
commit 6bc6111771
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 187 additions and 11 deletions

View File

@ -24,18 +24,20 @@ from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, intent, tts
from homeassistant.components import assist_pipeline, ffmpeg, intent, tts
from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration,
AssistSatelliteEntity,
AssistSatelliteEntityDescription,
AssistSatelliteEntityFeature,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
from .const import DOMAIN
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
from .data import WyomingService
from .devices import SatelliteDevice
from .entity import WyomingSatelliteEntity
@ -49,6 +51,8 @@ _RESTART_SECONDS: Final = 3
_PING_TIMEOUT: Final = 5
_PING_SEND_DELAY: Final = 2
_PIPELINE_FINISH_TIMEOUT: Final = 1
_TTS_SAMPLE_RATE: Final = 22050
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
@ -83,6 +87,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
_attr_translation_key = "assist_satellite"
_attr_name = None
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
def __init__(
self,
@ -116,6 +121,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed)
# For announcements
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
self._played_event_received: asyncio.Event | None = None
@property
def pipeline_entity_id(self) -> str | None:
"""Return the entity ID of the pipeline to use for the next conversation."""
@ -131,9 +140,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
"""Options passed for text-to-speech."""
return {
tts.ATTR_PREFERRED_FORMAT: "wav",
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
tts.ATTR_PREFERRED_SAMPLE_RATE: _TTS_SAMPLE_RATE,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: SAMPLE_CHANNELS,
tts.ATTR_PREFERRED_SAMPLE_BYTES: SAMPLE_WIDTH,
}
async def async_added_to_hass(self) -> None:
@ -244,6 +253,76 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
)
)
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite.
Should block until the announcement is done playing.
"""
assert self._client is not None
if self._ffmpeg_manager is None:
self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass)
if self._played_event_received is None:
self._played_event_received = asyncio.Event()
self._played_event_received.clear()
await self._client.write_event(
AudioStart(
rate=_TTS_SAMPLE_RATE,
width=SAMPLE_WIDTH,
channels=SAMPLE_CHANNELS,
timestamp=0,
).event()
)
timestamp = 0
try:
# Use ffmpeg to convert to raw PCM audio with the appropriate format
proc = await asyncio.create_subprocess_exec(
self._ffmpeg_manager.binary,
"-i",
announcement.media_id,
"-f",
"s16le",
"-ac",
str(SAMPLE_CHANNELS),
"-ar",
str(_TTS_SAMPLE_RATE),
"-nostats",
"pipe:",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
close_fds=False, # use posix_spawn in CPython < 3.13
)
assert proc.stdout is not None
while True:
chunk_bytes = await proc.stdout.read(_ANNOUNCE_CHUNK_BYTES)
if not chunk_bytes:
break
chunk = AudioChunk(
rate=_TTS_SAMPLE_RATE,
width=SAMPLE_WIDTH,
channels=SAMPLE_CHANNELS,
audio=chunk_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.milliseconds
finally:
await self._client.write_event(AudioStop().event())
if timestamp > 0:
# Wait the length of the audio or until we receive a played event
audio_seconds = timestamp / 1000
try:
async with asyncio.timeout(audio_seconds + 0.5):
await self._played_event_received.wait()
except TimeoutError:
# Older satellite clients will wait longer than necessary
_LOGGER.debug("Did not receive played event for announcement")
# -------------------------------------------------------------------------
def start_satellite(self) -> None:
@ -511,6 +590,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
elif Played.is_type(client_event.type):
# TTS response has finished playing on satellite
self.tts_response_finished()
if self._played_event_received is not None:
self._played_event_received.set()
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)

View File

@ -7,7 +7,8 @@
"assist_satellite",
"assist_pipeline",
"intent",
"conversation"
"conversation",
"ffmpeg"
],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"integration_type": "service",

View File

@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
import io
import tempfile
from typing import Any
from unittest.mock import patch
import wave
@ -17,17 +18,18 @@ from wyoming.info import Info
from wyoming.ping import Ping, Pong
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.snd import Played
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.const import STATE_ON
from homeassistant.core import HomeAssistant, State
from homeassistant.helpers import intent as intent_helper
from homeassistant.helpers import entity_registry as er, intent as intent_helper
from homeassistant.setup import async_setup_component
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
@ -65,7 +67,7 @@ def get_test_wav() -> bytes:
wav_file.setnchannels(1)
# Single frame
wav_file.writeframes(b"123")
wav_file.writeframes(b"1234")
return wav_io.getvalue()
@ -73,10 +75,15 @@ def get_test_wav() -> bytes:
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
"""Satellite AsyncTcpClient."""
def __init__(self, responses: list[Event]) -> None:
def __init__(
self, responses: list[Event], block_until_inject: bool = False
) -> None:
"""Initialize client."""
super().__init__(responses)
self.block_until_inject = block_until_inject
self._responses_ready = asyncio.Event()
self.connect_event = asyncio.Event()
self.run_satellite_event = asyncio.Event()
self.detect_event = asyncio.Event()
@ -188,6 +195,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
async def read_event(self) -> Event | None:
"""Receive."""
if self.block_until_inject and (not self.responses):
await self._responses_ready.wait()
event = await super().read_event()
# Keep sending audio chunks instead of None
@ -196,6 +206,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
def inject_event(self, event: Event) -> None:
"""Put an event in as the next response."""
self.responses = [event, *self.responses]
self._responses_ready.set()
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
@ -416,7 +427,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
assert mock_client.tts_audio_chunk.rate == 22050
assert mock_client.tts_audio_chunk.width == 2
assert mock_client.tts_audio_chunk.channels == 1
assert mock_client.tts_audio_chunk.audio == b"123"
assert mock_client.tts_audio_chunk.audio == b"1234"
# Pipeline finished
pipeline_event_callback(
@ -1283,3 +1294,85 @@ async def test_timers(hass: HomeAssistant) -> None:
timer_finished = mock_client.timer_finished
assert timer_finished is not None
assert timer_finished.id == timer_started.id
async def test_announce(
hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test announce on satellite."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
def async_process_play_media_url(hass: HomeAssistant, media_id: str) -> str:
# Don't create a URL
return media_id
with (
tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_wav_file,
patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
),
patch(
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(responses=[], block_until_inject=True),
) as mock_client,
patch(
"homeassistant.components.assist_satellite.entity.async_process_play_media_url",
new=async_process_play_media_url,
),
):
# Use test WAV data for media
with wave.open(temp_wav_file.name, "wb") as wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(22050 * 2)) # 1 sec
temp_wav_file.seek(0)
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
assert device is not None
satellite_entry = next(
(
maybe_entry
for maybe_entry in er.async_entries_for_device(
entity_registry, device.device_id
)
if maybe_entry.domain == assist_satellite.DOMAIN
),
None,
)
assert satellite_entry is not None
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
announce_task = hass.async_create_background_task(
hass.services.async_call(
assist_satellite.DOMAIN,
"announce",
{
"entity_id": satellite_entry.entity_id,
"media_id": temp_wav_file.name,
},
blocking=True,
),
"wyoming_satellite_announce",
)
# Wait for audio to come from ffmpeg
async with asyncio.timeout(1):
await mock_client.tts_audio_start_event.wait()
await mock_client.tts_audio_chunk_event.wait()
await mock_client.tts_audio_stop_event.wait()
# Stop announcement from blocking
mock_client.inject_event(Played().event())
await announce_task
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()