Add Wyoming satellite announce (#138221)
* Add Wyoming satellite announce * Initialize when necessarypull/138240/head
parent
f83c8de8d3
commit
6bc6111771
|
@ -24,18 +24,20 @@ from wyoming.tts import Synthesize, SynthesizeVoice
|
|||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, intent, tts
|
||||
from homeassistant.components import assist_pipeline, ffmpeg, intent, tts
|
||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||
from homeassistant.components.assist_satellite import (
|
||||
AssistSatelliteAnnouncement,
|
||||
AssistSatelliteConfiguration,
|
||||
AssistSatelliteEntity,
|
||||
AssistSatelliteEntityDescription,
|
||||
AssistSatelliteEntityFeature,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
|
||||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .entity import WyomingSatelliteEntity
|
||||
|
@ -49,6 +51,8 @@ _RESTART_SECONDS: Final = 3
|
|||
_PING_TIMEOUT: Final = 5
|
||||
_PING_SEND_DELAY: Final = 2
|
||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||
_TTS_SAMPLE_RATE: Final = 22050
|
||||
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
|
@ -83,6 +87,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||
_attr_translation_key = "assist_satellite"
|
||||
_attr_name = None
|
||||
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -116,6 +121,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
||||
|
||||
# For announcements
|
||||
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
|
||||
self._played_event_received: asyncio.Event | None = None
|
||||
|
||||
@property
|
||||
def pipeline_entity_id(self) -> str | None:
|
||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||
|
@ -131,9 +140,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||
"""Options passed for text-to-speech."""
|
||||
return {
|
||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: _TTS_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: SAMPLE_WIDTH,
|
||||
}
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
|
@ -244,6 +253,76 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||
)
|
||||
)
|
||||
|
||||
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||
"""Announce media on the satellite.
|
||||
|
||||
Should block until the announcement is done playing.
|
||||
"""
|
||||
assert self._client is not None
|
||||
|
||||
if self._ffmpeg_manager is None:
|
||||
self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass)
|
||||
|
||||
if self._played_event_received is None:
|
||||
self._played_event_received = asyncio.Event()
|
||||
|
||||
self._played_event_received.clear()
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=_TTS_SAMPLE_RATE,
|
||||
width=SAMPLE_WIDTH,
|
||||
channels=SAMPLE_CHANNELS,
|
||||
timestamp=0,
|
||||
).event()
|
||||
)
|
||||
|
||||
timestamp = 0
|
||||
try:
|
||||
# Use ffmpeg to convert to raw PCM audio with the appropriate format
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
self._ffmpeg_manager.binary,
|
||||
"-i",
|
||||
announcement.media_id,
|
||||
"-f",
|
||||
"s16le",
|
||||
"-ac",
|
||||
str(SAMPLE_CHANNELS),
|
||||
"-ar",
|
||||
str(_TTS_SAMPLE_RATE),
|
||||
"-nostats",
|
||||
"pipe:",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
close_fds=False, # use posix_spawn in CPython < 3.13
|
||||
)
|
||||
assert proc.stdout is not None
|
||||
while True:
|
||||
chunk_bytes = await proc.stdout.read(_ANNOUNCE_CHUNK_BYTES)
|
||||
if not chunk_bytes:
|
||||
break
|
||||
|
||||
chunk = AudioChunk(
|
||||
rate=_TTS_SAMPLE_RATE,
|
||||
width=SAMPLE_WIDTH,
|
||||
channels=SAMPLE_CHANNELS,
|
||||
audio=chunk_bytes,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
|
||||
timestamp += chunk.milliseconds
|
||||
finally:
|
||||
await self._client.write_event(AudioStop().event())
|
||||
if timestamp > 0:
|
||||
# Wait the length of the audio or until we receive a played event
|
||||
audio_seconds = timestamp / 1000
|
||||
try:
|
||||
async with asyncio.timeout(audio_seconds + 0.5):
|
||||
await self._played_event_received.wait()
|
||||
except TimeoutError:
|
||||
# Older satellite clients will wait longer than necessary
|
||||
_LOGGER.debug("Did not receive played event for announcement")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def start_satellite(self) -> None:
|
||||
|
@ -511,6 +590,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||
elif Played.is_type(client_event.type):
|
||||
# TTS response has finished playing on satellite
|
||||
self.tts_response_finished()
|
||||
|
||||
if self._played_event_received is not None:
|
||||
self._played_event_received.set()
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
"assist_satellite",
|
||||
"assist_pipeline",
|
||||
"intent",
|
||||
"conversation"
|
||||
"conversation",
|
||||
"ffmpeg"
|
||||
],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"integration_type": "service",
|
||||
|
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import io
|
||||
import tempfile
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
import wave
|
||||
|
@ -17,17 +18,18 @@ from wyoming.info import Info
|
|||
from wyoming.ping import Ping, Pong
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.snd import Played
|
||||
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
||||
from wyoming.tts import Synthesize
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, wyoming
|
||||
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
|
||||
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.const import STATE_ON
|
||||
from homeassistant.core import HomeAssistant, State
|
||||
from homeassistant.helpers import intent as intent_helper
|
||||
from homeassistant.helpers import entity_registry as er, intent as intent_helper
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||
|
@ -65,7 +67,7 @@ def get_test_wav() -> bytes:
|
|||
wav_file.setnchannels(1)
|
||||
|
||||
# Single frame
|
||||
wav_file.writeframes(b"123")
|
||||
wav_file.writeframes(b"1234")
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
@ -73,10 +75,15 @@ def get_test_wav() -> bytes:
|
|||
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||
"""Satellite AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses: list[Event]) -> None:
|
||||
def __init__(
|
||||
self, responses: list[Event], block_until_inject: bool = False
|
||||
) -> None:
|
||||
"""Initialize client."""
|
||||
super().__init__(responses)
|
||||
|
||||
self.block_until_inject = block_until_inject
|
||||
self._responses_ready = asyncio.Event()
|
||||
|
||||
self.connect_event = asyncio.Event()
|
||||
self.run_satellite_event = asyncio.Event()
|
||||
self.detect_event = asyncio.Event()
|
||||
|
@ -188,6 +195,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
|
||||
async def read_event(self) -> Event | None:
|
||||
"""Receive."""
|
||||
if self.block_until_inject and (not self.responses):
|
||||
await self._responses_ready.wait()
|
||||
|
||||
event = await super().read_event()
|
||||
|
||||
# Keep sending audio chunks instead of None
|
||||
|
@ -196,6 +206,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
|||
def inject_event(self, event: Event) -> None:
|
||||
"""Put an event in as the next response."""
|
||||
self.responses = [event, *self.responses]
|
||||
self._responses_ready.set()
|
||||
|
||||
|
||||
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||
|
@ -416,7 +427,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||
assert mock_client.tts_audio_chunk.rate == 22050
|
||||
assert mock_client.tts_audio_chunk.width == 2
|
||||
assert mock_client.tts_audio_chunk.channels == 1
|
||||
assert mock_client.tts_audio_chunk.audio == b"123"
|
||||
assert mock_client.tts_audio_chunk.audio == b"1234"
|
||||
|
||||
# Pipeline finished
|
||||
pipeline_event_callback(
|
||||
|
@ -1283,3 +1294,85 @@ async def test_timers(hass: HomeAssistant) -> None:
|
|||
timer_finished = mock_client.timer_finished
|
||||
assert timer_finished is not None
|
||||
assert timer_finished.id == timer_started.id
|
||||
|
||||
|
||||
async def test_announce(
|
||||
hass: HomeAssistant, entity_registry: er.EntityRegistry
|
||||
) -> None:
|
||||
"""Test announce on satellite."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
def async_process_play_media_url(hass: HomeAssistant, media_id: str) -> str:
|
||||
# Don't create a URL
|
||||
return media_id
|
||||
|
||||
with (
|
||||
tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as temp_wav_file,
|
||||
patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(responses=[], block_until_inject=True),
|
||||
) as mock_client,
|
||||
patch(
|
||||
"homeassistant.components.assist_satellite.entity.async_process_play_media_url",
|
||||
new=async_process_play_media_url,
|
||||
),
|
||||
):
|
||||
# Use test WAV data for media
|
||||
with wave.open(temp_wav_file.name, "wb") as wav_file:
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(22050 * 2)) # 1 sec
|
||||
|
||||
temp_wav_file.seek(0)
|
||||
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
|
||||
assert device is not None
|
||||
|
||||
satellite_entry = next(
|
||||
(
|
||||
maybe_entry
|
||||
for maybe_entry in er.async_entries_for_device(
|
||||
entity_registry, device.device_id
|
||||
)
|
||||
if maybe_entry.domain == assist_satellite.DOMAIN
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert satellite_entry is not None
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
announce_task = hass.async_create_background_task(
|
||||
hass.services.async_call(
|
||||
assist_satellite.DOMAIN,
|
||||
"announce",
|
||||
{
|
||||
"entity_id": satellite_entry.entity_id,
|
||||
"media_id": temp_wav_file.name,
|
||||
},
|
||||
blocking=True,
|
||||
),
|
||||
"wyoming_satellite_announce",
|
||||
)
|
||||
|
||||
# Wait for audio to come from ffmpeg
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.tts_audio_start_event.wait()
|
||||
await mock_client.tts_audio_chunk_event.wait()
|
||||
await mock_client.tts_audio_stop_event.wait()
|
||||
|
||||
# Stop announcement from blocking
|
||||
mock_client.inject_event(Played().event())
|
||||
await announce_task
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
|
Loading…
Reference in New Issue