# core/homeassistant/components/wyoming/satellite.py
"""Support for Wyoming satellite services."""
import asyncio
from collections.abc import AsyncGenerator
import io
import logging
from typing import Final
import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import select as pipeline_select
from homeassistant.core import Context, HomeAssistant
from .const import DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice

_LOGGER = logging.getLogger(__name__)

_SAMPLES_PER_CHUNK: Final = 1024
_RECONNECT_SECONDS: Final = 10
_RESTART_SECONDS: Final = 3

# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
    PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
    PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
    PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
    PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
}


class WyomingSatellite:
    """Remote voice satellite running the Wyoming protocol."""

    def __init__(
        self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice
    ) -> None:
        """Initialize satellite."""
        self.hass = hass
        self.service = service
        self.device = device
        self.is_running = True

        self._client: AsyncTcpClient | None = None
        self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
        self._is_pipeline_running = False
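        # Microphone audio is pushed into this queue by _run_once() and read
        # back by _stt_stream(). Falsy items act as sentinels that end the
        # stream: the change listeners enqueue None to cancel a run, and an
        # AudioStop event enqueues b"" to finish it.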
        self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
        self._pipeline_id: str | None = None
        self._muted_changed_event = asyncio.Event()

        self.device.set_is_muted_listener(self._muted_changed)
        self.device.set_pipeline_listener(self._pipeline_changed)
        self.device.set_audio_settings_listener(self._audio_settings_changed)

    async def run(self) -> None:
        """Run and maintain a connection to satellite."""
        _LOGGER.debug("Running satellite task")

        try:
            while self.is_running:
                try:
                    # Check if satellite has been muted
                    while self.device.is_muted:
                        await self.on_muted()
                        if not self.is_running:
                            # Satellite was stopped while waiting to be unmuted
                            return

                    # Connect and run pipeline loop
                    await self._run_once()
                except asyncio.CancelledError:
                    raise
                except Exception:  # pylint: disable=broad-exception-caught
                    await self.on_restart()
        finally:
            # Ensure sensor is off
            self.device.set_is_active(False)

            await self.on_stopped()

    def stop(self) -> None:
        """Signal satellite task to stop running."""
        self.is_running = False

        # Unblock waiting for unmuted
        self._muted_changed_event.set()

    async def on_restart(self) -> None:
        """Block until pipeline loop will be restarted."""
        _LOGGER.warning(
            "Unexpected error running satellite. Restarting in %s second(s)",
            _RESTART_SECONDS,
        )
        await asyncio.sleep(_RESTART_SECONDS)

    async def on_reconnect(self) -> None:
        """Block until a reconnection attempt should be made."""
        _LOGGER.debug(
            "Failed to connect to satellite. Reconnecting in %s second(s)",
            _RECONNECT_SECONDS,
        )
        await asyncio.sleep(_RECONNECT_SECONDS)

    async def on_muted(self) -> None:
        """Block until device may be unmuted again."""
        await self._muted_changed_event.wait()

    async def on_stopped(self) -> None:
        """Run when run() has fully stopped."""
        _LOGGER.debug("Satellite task stopped")

    # -------------------------------------------------------------------------

    def _muted_changed(self) -> None:
        """Run when device muted status changes."""
        if self.device.is_muted:
            # Cancel any running pipeline
            self._audio_queue.put_nowait(None)
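        # Wake up anything currently waiting on the event, then clear it so
        # future waits block until the next mute change.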
        self._muted_changed_event.set()
        self._muted_changed_event.clear()

    def _pipeline_changed(self) -> None:
        """Run when device pipeline changes."""
        # Cancel any running pipeline
        self._audio_queue.put_nowait(None)

    def _audio_settings_changed(self) -> None:
        """Run when device audio settings change."""
        # Cancel any running pipeline
        self._audio_queue.put_nowait(None)

    async def _run_once(self) -> None:
        """Run pipelines until an error occurs."""
        self.device.set_is_active(False)

        while self.is_running and (not self.device.is_muted):
            try:
                await self._connect()
                break
            except ConnectionError:
                await self.on_reconnect()

        assert self._client is not None
        _LOGGER.debug("Connected to satellite")

        if (not self.is_running) or self.device.is_muted:
            # Run was cancelled or satellite was disabled during connection
            return

        # Tell satellite that we're ready
        await self._client.write_event(RunSatellite().event())

        # Wait until we get RunPipeline event
        run_pipeline: RunPipeline | None = None
        while self.is_running and (not self.device.is_muted):
            run_event = await self._client.read_event()
            if run_event is None:
                raise ConnectionResetError("Satellite disconnected")

            if RunPipeline.is_type(run_event.type):
                run_pipeline = RunPipeline.from_event(run_event)
                break

            _LOGGER.debug("Unexpected event from satellite: %s", run_event)

        assert run_pipeline is not None
        _LOGGER.debug("Received run information: %s", run_pipeline)

        if (not self.is_running) or self.device.is_muted:
            # Run was cancelled or satellite was disabled while waiting for
            # RunPipeline event.
            return
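
        # The satellite chooses which stages to run. A satellite that detects
        # its own wake word will typically request a run starting at ASR,
        # while one that streams audio for wake word detection starts at WAKE.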
        start_stage = _STAGES.get(run_pipeline.start_stage)
        end_stage = _STAGES.get(run_pipeline.end_stage)

        if start_stage is None:
            raise ValueError(f"Invalid start stage: {run_pipeline.start_stage}")

        if end_stage is None:
            raise ValueError(f"Invalid end stage: {run_pipeline.end_stage}")

        # Each loop is a pipeline run
        while self.is_running and (not self.device.is_muted):
            # Use select to get pipeline each time in case it's changed
            pipeline_id = pipeline_select.get_chosen_pipeline(
                self.hass,
                DOMAIN,
                self.device.satellite_id,
            )
            pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
            assert pipeline is not None

            # We will push audio in through a queue
            self._audio_queue = asyncio.Queue()
            stt_stream = self._stt_stream()

            # Start pipeline running
            _LOGGER.debug(
                "Starting pipeline %s from %s to %s",
                pipeline.name,
                start_stage,
                end_stage,
            )
            self._is_pipeline_running = True
            _pipeline_task = asyncio.create_task(
                assist_pipeline.async_pipeline_from_audio_stream(
                    self.hass,
                    context=Context(),
                    event_callback=self._event_callback,
                    stt_metadata=stt.SpeechMetadata(
                        language=pipeline.language,
                        format=stt.AudioFormats.WAV,
                        codec=stt.AudioCodecs.PCM,
                        bit_rate=stt.AudioBitRates.BITRATE_16,
                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                        channel=stt.AudioChannels.CHANNEL_MONO,
                    ),
                    stt_stream=stt_stream,
                    start_stage=start_stage,
                    end_stage=end_stage,
                    tts_audio_output="wav",
                    pipeline_id=pipeline_id,
                    audio_settings=assist_pipeline.AudioSettings(
                        noise_suppression_level=self.device.noise_suppression_level,
                        auto_gain_dbfs=self.device.auto_gain,
                        volume_multiplier=self.device.volume_multiplier,
                    ),
                    device_id=self.device.device_id,
                )
            )

            # Run until pipeline is complete or cancelled with an empty audio chunk
            while self._is_pipeline_running:
                client_event = await self._client.read_event()
                if client_event is None:
                    raise ConnectionResetError("Satellite disconnected")

                if AudioChunk.is_type(client_event.type):
                    # Microphone audio
                    chunk = AudioChunk.from_event(client_event)
                    chunk = self._chunk_converter.convert(chunk)
                    self._audio_queue.put_nowait(chunk.audio)
                elif AudioStop.is_type(client_event.type):
                    # Stop pipeline
                    _LOGGER.debug("Client requested pipeline to stop")
                    self._audio_queue.put_nowait(b"")
                    break
                else:
                    _LOGGER.debug("Unexpected event from satellite: %s", client_event)

            # Ensure task finishes
            await _pipeline_task

            _LOGGER.debug("Pipeline finished")

    def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
        """Translate pipeline events into Wyoming events."""
        assert self._client is not None
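        # This callback runs synchronously inside the pipeline, so coroutine
        # writes to the satellite are scheduled with hass.add_job() instead of
        # being awaited here.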
        if event.type == assist_pipeline.PipelineEventType.RUN_END:
            # Pipeline run is complete
            self._is_pipeline_running = False
            self.device.set_is_active(False)
        elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
            self.hass.add_job(self._client.write_event(Detect().event()))
        elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
            # Wake word detection
            self.device.set_is_active(True)

            # Inform client of wake word detection
            if event.data and (wake_word_output := event.data.get("wake_word_output")):
                detection = Detection(
                    name=wake_word_output["wake_word_id"],
                    timestamp=wake_word_output.get("timestamp"),
                )
                self.hass.add_job(self._client.write_event(detection.event()))
        elif event.type == assist_pipeline.PipelineEventType.STT_START:
            # Speech-to-text
            self.device.set_is_active(True)

            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        Transcribe(language=event.data["metadata"]["language"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
            # User started speaking
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        VoiceStarted(timestamp=event.data["timestamp"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
            # User stopped speaking
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        VoiceStopped(timestamp=event.data["timestamp"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_END:
            # Speech-to-text transcript
            if event.data:
                # Inform client of transcript
                stt_text = event.data["stt_output"]["text"]
                self.hass.add_job(
                    self._client.write_event(Transcript(text=stt_text).event())
                )
        elif event.type == assist_pipeline.PipelineEventType.TTS_START:
            # Text-to-speech text
            if event.data:
                # Inform client of text
                self.hass.add_job(
                    self._client.write_event(
                        Synthesize(
                            text=event.data["tts_input"],
                            voice=SynthesizeVoice(
                                name=event.data.get("voice"),
                                language=event.data.get("language"),
                            ),
                        ).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.TTS_END:
            # TTS stream
            if event.data and (tts_output := event.data["tts_output"]):
                media_id = tts_output["media_id"]
                self.hass.add_job(self._stream_tts(media_id))
        elif event.type == assist_pipeline.PipelineEventType.ERROR:
            # Pipeline error
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        Error(
                            text=event.data["message"], code=event.data["code"]
                        ).event()
                    )
                )

    async def _connect(self) -> None:
        """Connect to satellite over TCP."""
        await self._disconnect()

        _LOGGER.debug(
            "Connecting to satellite at %s:%s", self.service.host, self.service.port
        )
        self._client = AsyncTcpClient(self.service.host, self.service.port)
        await self._client.connect()

    async def _disconnect(self) -> None:
        """Disconnect if satellite is currently connected."""
        if self._client is None:
            return

        _LOGGER.debug("Disconnecting from satellite")
        await self._client.disconnect()
        self._client = None

    async def _stream_tts(self, media_id: str) -> None:
        """Stream TTS WAV audio to satellite in chunks."""
        assert self._client is not None

        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
        if extension != "wav":
            raise ValueError(f"Cannot stream audio format to satellite: {extension}")

        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            sample_width = wav_file.getsampwidth()
            sample_channels = wav_file.getnchannels()
            _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())

            timestamp = 0
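            # The outgoing timestamp is advanced by each chunk's duration so
            # the final AudioStop event marks the end of the clip.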
            await self._client.write_event(
                AudioStart(
                    rate=sample_rate,
                    width=sample_width,
                    channels=sample_channels,
                    timestamp=timestamp,
                ).event()
            )

            # Stream audio chunks
            while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
                chunk = AudioChunk(
                    rate=sample_rate,
                    width=sample_width,
                    channels=sample_channels,
                    audio=audio_bytes,
                    timestamp=timestamp,
                )
                await self._client.write_event(chunk.event())
                timestamp += chunk.seconds

            await self._client.write_event(AudioStop(timestamp=timestamp).event())
            _LOGGER.debug("TTS streaming complete")

    async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
        """Yield audio chunks from a queue."""
        is_first_chunk = True
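        # Stops when a falsy sentinel (None or b"") is dequeued, i.e. when the
        # run is cancelled by a settings change or the satellite sends AudioStop.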
        while chunk := await self._audio_queue.get():
            if is_first_chunk:
                is_first_chunk = False
                _LOGGER.debug("Receiving audio from satellite")

            yield chunk
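

# Rough usage sketch (hypothetical wiring shown for illustration; the actual
# integration creates and starts the satellite during config entry setup):
#
#     satellite = WyomingSatellite(hass, service, device)
#     entry.async_create_background_task(hass, satellite.run(), "wyoming satellite")
#     ...
#     satellite.stop()  # unblocks on_muted() and lets run() exit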