Add Wyoming satellite announce (#138221)
* Add Wyoming satellite announce
* Initialize when necessary
pull/138240/head
parent
f83c8de8d3
commit
6bc6111771
|
@ -24,18 +24,20 @@ from wyoming.tts import Synthesize, SynthesizeVoice
|
||||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||||
from wyoming.wake import Detect, Detection
|
from wyoming.wake import Detect, Detection
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, intent, tts
|
from homeassistant.components import assist_pipeline, ffmpeg, intent, tts
|
||||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
|
AssistSatelliteAnnouncement,
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityDescription,
|
AssistSatelliteEntityDescription,
|
||||||
|
AssistSatelliteEntityFeature,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddConfigEntryEntitiesCallback
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_WIDTH
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
from .devices import SatelliteDevice
|
from .devices import SatelliteDevice
|
||||||
from .entity import WyomingSatelliteEntity
|
from .entity import WyomingSatelliteEntity
|
||||||
|
@ -49,6 +51,8 @@ _RESTART_SECONDS: Final = 3
|
||||||
_PING_TIMEOUT: Final = 5
|
_PING_TIMEOUT: Final = 5
|
||||||
_PING_SEND_DELAY: Final = 2
|
_PING_SEND_DELAY: Final = 2
|
||||||
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
_PIPELINE_FINISH_TIMEOUT: Final = 1
|
||||||
|
_TTS_SAMPLE_RATE: Final = 22050
|
||||||
|
_ANNOUNCE_CHUNK_BYTES: Final = 2048 # 1024 samples
|
||||||
|
|
||||||
# Wyoming stage -> Assist stage
|
# Wyoming stage -> Assist stage
|
||||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||||
|
@ -83,6 +87,7 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||||
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
entity_description = AssistSatelliteEntityDescription(key="assist_satellite")
|
||||||
_attr_translation_key = "assist_satellite"
|
_attr_translation_key = "assist_satellite"
|
||||||
_attr_name = None
|
_attr_name = None
|
||||||
|
_attr_supported_features = AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -116,6 +121,10 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||||
self.device.set_pipeline_listener(self._pipeline_changed)
|
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||||
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
||||||
|
|
||||||
|
# For announcements
|
||||||
|
self._ffmpeg_manager: ffmpeg.FFmpegManager | None = None
|
||||||
|
self._played_event_received: asyncio.Event | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def pipeline_entity_id(self) -> str | None:
|
def pipeline_entity_id(self) -> str | None:
|
||||||
"""Return the entity ID of the pipeline to use for the next conversation."""
|
"""Return the entity ID of the pipeline to use for the next conversation."""
|
||||||
|
@ -131,9 +140,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||||
"""Options passed for text-to-speech."""
|
"""Options passed for text-to-speech."""
|
||||||
return {
|
return {
|
||||||
tts.ATTR_PREFERRED_FORMAT: "wav",
|
tts.ATTR_PREFERRED_FORMAT: "wav",
|
||||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 16000,
|
tts.ATTR_PREFERRED_SAMPLE_RATE: _TTS_SAMPLE_RATE,
|
||||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
|
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: SAMPLE_CHANNELS,
|
||||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
tts.ATTR_PREFERRED_SAMPLE_BYTES: SAMPLE_WIDTH,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def async_added_to_hass(self) -> None:
|
async def async_added_to_hass(self) -> None:
|
||||||
|
@ -244,6 +253,76 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
    """Announce media on the satellite.

    Should block until the announcement is done playing.

    The media referenced by ``announcement.media_id`` is converted by ffmpeg
    to raw signed 16-bit PCM (``_TTS_SAMPLE_RATE`` Hz, ``SAMPLE_CHANNELS``
    channels) and streamed to the satellite as an audio-start event, a series
    of audio chunks of at most ``_ANNOUNCE_CHUNK_BYTES`` bytes, and an
    audio-stop event. Afterwards this waits for the satellite's played event,
    or for the length of the streamed audio plus half a second, whichever
    comes first.
    """
    assert self._client is not None

    # The announcement helpers are created lazily, on the first announcement.
    if self._ffmpeg_manager is None:
        self._ffmpeg_manager = ffmpeg.get_ffmpeg_manager(self.hass)

    if self._played_event_received is None:
        self._played_event_received = asyncio.Event()

    # Discard a played event left over from earlier audio so that the wait
    # below only reacts to an event received from now on.
    self._played_event_received.clear()
    await self._client.write_event(
        AudioStart(
            rate=_TTS_SAMPLE_RATE,
            width=SAMPLE_WIDTH,
            channels=SAMPLE_CHANNELS,
            timestamp=0,
        ).event()
    )

    timestamp = 0  # milliseconds of audio sent so far
    proc: asyncio.subprocess.Process | None = None
    try:
        # Use ffmpeg to convert to raw PCM audio with the appropriate format
        proc = await asyncio.create_subprocess_exec(
            self._ffmpeg_manager.binary,
            "-i",
            announcement.media_id,
            "-f",
            "s16le",
            "-ac",
            str(SAMPLE_CHANNELS),
            "-ar",
            str(_TTS_SAMPLE_RATE),
            "-nostats",
            "pipe:",
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            close_fds=False,  # use posix_spawn in CPython < 3.13
        )
        assert proc.stdout is not None
        while chunk_bytes := await proc.stdout.read(_ANNOUNCE_CHUNK_BYTES):
            chunk = AudioChunk(
                rate=_TTS_SAMPLE_RATE,
                width=SAMPLE_WIDTH,
                channels=SAMPLE_CHANNELS,
                audio=chunk_bytes,
                timestamp=timestamp,
            )
            await self._client.write_event(chunk.event())

            timestamp += chunk.milliseconds

        # stdout reached EOF. Drain stderr and reap the process with
        # communicate(); a bare wait() can deadlock while a pipe is unread.
        _, stderr_bytes = await proc.communicate()
        if proc.returncode:
            _LOGGER.warning(
                "ffmpeg exited with code %s while converting announcement: %s",
                proc.returncode,
                stderr_bytes.decode(errors="replace").strip(),
            )
    finally:
        try:
            if proc is not None and proc.returncode is None:
                # The read loop was left early (error or cancellation), so
                # ffmpeg may still be running: stop it and reap it rather
                # than leaking the process.
                try:
                    proc.kill()
                except ProcessLookupError:
                    pass  # exited between the returncode check and kill()
                await proc.wait()
        finally:
            # Always close the audio stream on the satellite, even when the
            # conversion or the process cleanup failed.
            await self._client.write_event(AudioStop().event())
            if timestamp > 0:
                # Wait the length of the audio or until we receive a played event
                audio_seconds = timestamp / 1000
                try:
                    async with asyncio.timeout(audio_seconds + 0.5):
                        await self._played_event_received.wait()
                except TimeoutError:
                    # Older satellite clients will wait longer than necessary
                    _LOGGER.debug("Did not receive played event for announcement")
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
def start_satellite(self) -> None:
|
def start_satellite(self) -> None:
|
||||||
|
@ -511,6 +590,9 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
||||||
elif Played.is_type(client_event.type):
|
elif Played.is_type(client_event.type):
|
||||||
# TTS response has finished playing on satellite
|
# TTS response has finished playing on satellite
|
||||||
self.tts_response_finished()
|
self.tts_response_finished()
|
||||||
|
|
||||||
|
if self._played_event_received is not None:
|
||||||
|
self._played_event_received.set()
|
||||||
else:
|
else:
|
||||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,8 @@
|
||||||
"assist_satellite",
|
"assist_satellite",
|
||||||
"assist_pipeline",
|
"assist_pipeline",
|
||||||
"intent",
|
"intent",
|
||||||
"conversation"
|
"conversation",
|
||||||
|
"ffmpeg"
|
||||||
],
|
],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||||
"integration_type": "service",
|
"integration_type": "service",
|
||||||
|
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
import io
|
import io
|
||||||
|
import tempfile
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import wave
|
import wave
|
||||||
|
@ -17,17 +18,18 @@ from wyoming.info import Info
|
||||||
from wyoming.ping import Ping, Pong
|
from wyoming.ping import Ping, Pong
|
||||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||||
from wyoming.satellite import RunSatellite
|
from wyoming.satellite import RunSatellite
|
||||||
|
from wyoming.snd import Played
|
||||||
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
from wyoming.timer import TimerCancelled, TimerFinished, TimerStarted, TimerUpdated
|
||||||
from wyoming.tts import Synthesize
|
from wyoming.tts import Synthesize
|
||||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||||
from wyoming.wake import Detect, Detection
|
from wyoming.wake import Detect, Detection
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, wyoming
|
from homeassistant.components import assist_pipeline, assist_satellite, wyoming
|
||||||
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
from homeassistant.components.wyoming.assist_satellite import WyomingAssistSatellite
|
||||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
from homeassistant.const import STATE_ON
|
from homeassistant.const import STATE_ON
|
||||||
from homeassistant.core import HomeAssistant, State
|
from homeassistant.core import HomeAssistant, State
|
||||||
from homeassistant.helpers import intent as intent_helper
|
from homeassistant.helpers import entity_registry as er, intent as intent_helper
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||||
|
@ -65,7 +67,7 @@ def get_test_wav() -> bytes:
|
||||||
wav_file.setnchannels(1)
|
wav_file.setnchannels(1)
|
||||||
|
|
||||||
# Single frame
|
# Single frame
|
||||||
wav_file.writeframes(b"123")
|
wav_file.writeframes(b"1234")
|
||||||
|
|
||||||
return wav_io.getvalue()
|
return wav_io.getvalue()
|
||||||
|
|
||||||
|
@ -73,10 +75,15 @@ def get_test_wav() -> bytes:
|
||||||
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||||
"""Satellite AsyncTcpClient."""
|
"""Satellite AsyncTcpClient."""
|
||||||
|
|
||||||
def __init__(self, responses: list[Event]) -> None:
|
def __init__(
|
||||||
|
self, responses: list[Event], block_until_inject: bool = False
|
||||||
|
) -> None:
|
||||||
"""Initialize client."""
|
"""Initialize client."""
|
||||||
super().__init__(responses)
|
super().__init__(responses)
|
||||||
|
|
||||||
|
self.block_until_inject = block_until_inject
|
||||||
|
self._responses_ready = asyncio.Event()
|
||||||
|
|
||||||
self.connect_event = asyncio.Event()
|
self.connect_event = asyncio.Event()
|
||||||
self.run_satellite_event = asyncio.Event()
|
self.run_satellite_event = asyncio.Event()
|
||||||
self.detect_event = asyncio.Event()
|
self.detect_event = asyncio.Event()
|
||||||
|
@ -188,6 +195,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||||
|
|
||||||
async def read_event(self) -> Event | None:
|
async def read_event(self) -> Event | None:
|
||||||
"""Receive."""
|
"""Receive."""
|
||||||
|
if self.block_until_inject and (not self.responses):
|
||||||
|
await self._responses_ready.wait()
|
||||||
|
|
||||||
event = await super().read_event()
|
event = await super().read_event()
|
||||||
|
|
||||||
# Keep sending audio chunks instead of None
|
# Keep sending audio chunks instead of None
|
||||||
|
@ -196,6 +206,7 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||||
def inject_event(self, event: Event) -> None:
|
def inject_event(self, event: Event) -> None:
|
||||||
"""Put an event in as the next response."""
|
"""Put an event in as the next response."""
|
||||||
self.responses = [event, *self.responses]
|
self.responses = [event, *self.responses]
|
||||||
|
self._responses_ready.set()
|
||||||
|
|
||||||
|
|
||||||
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||||
|
@ -416,7 +427,7 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||||
assert mock_client.tts_audio_chunk.rate == 22050
|
assert mock_client.tts_audio_chunk.rate == 22050
|
||||||
assert mock_client.tts_audio_chunk.width == 2
|
assert mock_client.tts_audio_chunk.width == 2
|
||||||
assert mock_client.tts_audio_chunk.channels == 1
|
assert mock_client.tts_audio_chunk.channels == 1
|
||||||
assert mock_client.tts_audio_chunk.audio == b"123"
|
assert mock_client.tts_audio_chunk.audio == b"1234"
|
||||||
|
|
||||||
# Pipeline finished
|
# Pipeline finished
|
||||||
pipeline_event_callback(
|
pipeline_event_callback(
|
||||||
|
@ -1283,3 +1294,85 @@ async def test_timers(hass: HomeAssistant) -> None:
|
||||||
timer_finished = mock_client.timer_finished
|
timer_finished = mock_client.timer_finished
|
||||||
assert timer_finished is not None
|
assert timer_finished is not None
|
||||||
assert timer_finished.id == timer_started.id
|
assert timer_finished.id == timer_started.id
|
||||||
|
|
||||||
|
|
||||||
|
async def test_announce(
    hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
    """Check that the announce service streams media to the satellite.

    A one second WAV file is announced through the assist_satellite
    "announce" service; the test waits for the audio start/chunk/stop events
    on the mock client and then injects a played event so the blocking
    service call can finish.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    def async_process_play_media_url(hass: HomeAssistant, media_id: str) -> str:
        # Hand the media id back unchanged instead of building a URL
        return media_id

    # Stand-in for the TCP client that blocks reads until an event is injected
    client = SatelliteAsyncTcpClient(responses=[], block_until_inject=True)

    with (
        tempfile.NamedTemporaryFile(mode="wb+", suffix=".wav") as media_file,
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.assist_satellite.AsyncTcpClient",
            client,
        ),
        patch(
            "homeassistant.components.assist_satellite.entity.async_process_play_media_url",
            new=async_process_play_media_url,
        ),
    ):
        # Fill the temporary file with one second of 22050 Hz, 16-bit mono audio
        with wave.open(media_file.name, "wb") as wav_writer:
            wav_writer.setnchannels(1)
            wav_writer.setsampwidth(2)
            wav_writer.setframerate(22050)
            wav_writer.writeframes(b"\x00" * (22050 * 2))  # 1 sec

        media_file.seek(0)

        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][entry.entry_id].device
        assert device is not None

        # Locate the assist satellite entity that belongs to the device
        satellite_entry = None
        for candidate in er.async_entries_for_device(
            entity_registry, device.device_id
        ):
            if candidate.domain == assist_satellite.DOMAIN:
                satellite_entry = candidate
                break
        assert satellite_entry is not None

        # The satellite must be connected and running before announcing
        async with asyncio.timeout(1):
            for ready_event in (client.connect_event, client.run_satellite_event):
                await ready_event.wait()

        service_data = {
            "entity_id": satellite_entry.entity_id,
            "media_id": media_file.name,
        }
        announce_task = hass.async_create_background_task(
            hass.services.async_call(
                assist_satellite.DOMAIN,
                "announce",
                service_data,
                blocking=True,
            ),
            "wyoming_satellite_announce",
        )

        # Audio converted by ffmpeg should now arrive at the client
        async with asyncio.timeout(1):
            await client.tts_audio_start_event.wait()
            await client.tts_audio_chunk_event.wait()
            await client.tts_audio_stop_event.wait()

        # Report playback as finished so the announcement stops blocking
        client.inject_event(Played().event())
        await announce_task

        # Shut the satellite down
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
|
||||||
|
|
Loading…
Reference in New Issue