423 lines
16 KiB
Python
423 lines
16 KiB
Python
"""Support for Wyoming satellite services."""
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator
|
|
import io
|
|
import logging
|
|
from typing import Final
|
|
import wave
|
|
|
|
from wyoming.asr import Transcribe, Transcript
|
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
|
from wyoming.client import AsyncTcpClient
|
|
from wyoming.error import Error
|
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
|
from wyoming.satellite import RunSatellite
|
|
from wyoming.tts import Synthesize, SynthesizeVoice
|
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
|
from wyoming.wake import Detect, Detection
|
|
|
|
from homeassistant.components import assist_pipeline, stt, tts
|
|
from homeassistant.components.assist_pipeline import select as pipeline_select
|
|
from homeassistant.core import Context, HomeAssistant
|
|
|
|
from .const import DOMAIN
|
|
from .data import WyomingService
|
|
from .devices import SatelliteDevice
|
|
|
|
# Module logger: use __name__ so log records are attributed to this module
# instead of being attached to the root logger.
_LOGGER = logging.getLogger(__name__)

# Number of audio samples sent to the satellite per TTS chunk.
_SAMPLES_PER_CHUNK: Final = 1024
# Seconds to wait before retrying a failed connection.
_RECONNECT_SECONDS: Final = 10
# Seconds to wait before restarting the pipeline loop after an error.
_RESTART_SECONDS: Final = 3

# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
    PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
    PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
    PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
    PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
}
|
|
|
|
|
|
class WyomingSatellite:
|
|
"""Remove voice satellite running the Wyoming protocol."""
|
|
|
|
def __init__(
|
|
self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice
|
|
) -> None:
|
|
"""Initialize satellite."""
|
|
self.hass = hass
|
|
self.service = service
|
|
self.device = device
|
|
self.is_running = True
|
|
|
|
self._client: AsyncTcpClient | None = None
|
|
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
|
|
self._is_pipeline_running = False
|
|
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
|
self._pipeline_id: str | None = None
|
|
self._muted_changed_event = asyncio.Event()
|
|
|
|
self.device.set_is_muted_listener(self._muted_changed)
|
|
self.device.set_pipeline_listener(self._pipeline_changed)
|
|
self.device.set_audio_settings_listener(self._audio_settings_changed)
|
|
|
|
    async def run(self) -> None:
        """Run and maintain a connection to satellite.

        Loops until stop() is called: blocks while the device is muted,
        then connects and runs the pipeline loop via _run_once(). Any
        unexpected error triggers on_restart() and the loop retries.
        """
        _LOGGER.debug("Running satellite task")

        try:
            while self.is_running:
                try:
                    # Check if satellite has been muted
                    while self.device.is_muted:
                        await self.on_muted()
                        if not self.is_running:
                            # Satellite was stopped while waiting to be unmuted
                            return

                    # Connect and run pipeline loop
                    await self._run_once()
                except asyncio.CancelledError:
                    # Never swallow task cancellation.
                    raise
                except Exception:  # pylint: disable=broad-exception-caught
                    # Back off, then retry from the top of the loop.
                    await self.on_restart()
        finally:
            # Ensure sensor is off
            self.device.set_is_active(False)

            await self.on_stopped()
|
|
|
|
def stop(self) -> None:
|
|
"""Signal satellite task to stop running."""
|
|
self.is_running = False
|
|
|
|
# Unblock waiting for unmuted
|
|
self._muted_changed_event.set()
|
|
|
|
async def on_restart(self) -> None:
|
|
"""Block until pipeline loop will be restarted."""
|
|
_LOGGER.warning(
|
|
"Unexpected error running satellite. Restarting in %s second(s)",
|
|
_RECONNECT_SECONDS,
|
|
)
|
|
await asyncio.sleep(_RESTART_SECONDS)
|
|
|
|
async def on_reconnect(self) -> None:
|
|
"""Block until a reconnection attempt should be made."""
|
|
_LOGGER.debug(
|
|
"Failed to connect to satellite. Reconnecting in %s second(s)",
|
|
_RECONNECT_SECONDS,
|
|
)
|
|
await asyncio.sleep(_RECONNECT_SECONDS)
|
|
|
|
    async def on_muted(self) -> None:
        """Block until device may be unmuted again.

        The event is pulsed (set then cleared) by _muted_changed() and
        set permanently by stop().
        """
        await self._muted_changed_event.wait()
|
|
|
|
    async def on_stopped(self) -> None:
        """Run when run() has fully stopped.

        Hook invoked from run()'s finally block; subclasses may override
        for cleanup or notification.
        """
        _LOGGER.debug("Satellite task stopped")
|
|
|
|
# -------------------------------------------------------------------------
|
|
|
|
def _muted_changed(self) -> None:
|
|
"""Run when device muted status changes."""
|
|
if self.device.is_muted:
|
|
# Cancel any running pipeline
|
|
self._audio_queue.put_nowait(None)
|
|
|
|
self._muted_changed_event.set()
|
|
self._muted_changed_event.clear()
|
|
|
|
def _pipeline_changed(self) -> None:
|
|
"""Run when device pipeline changes."""
|
|
|
|
# Cancel any running pipeline
|
|
self._audio_queue.put_nowait(None)
|
|
|
|
def _audio_settings_changed(self) -> None:
|
|
"""Run when device audio settings."""
|
|
|
|
# Cancel any running pipeline
|
|
self._audio_queue.put_nowait(None)
|
|
|
|
    async def _run_once(self) -> None:
        """Run pipelines until an error occurs.

        Connects to the satellite (retrying on ConnectionError), waits
        for a RunPipeline event, then runs Assist pipelines in a loop,
        feeding satellite microphone audio into _stt_stream()'s queue.

        Raises:
            ConnectionResetError: if the satellite disconnects.
            ValueError: if the requested start/end stage is unknown.
        """
        self.device.set_is_active(False)

        # Keep trying to connect until we succeed, are stopped, or muted.
        while self.is_running and (not self.device.is_muted):
            try:
                await self._connect()
                break
            except ConnectionError:
                await self.on_reconnect()

        assert self._client is not None
        _LOGGER.debug("Connected to satellite")

        if (not self.is_running) or self.device.is_muted:
            # Run was cancelled or satellite was disabled during connection
            return

        # Tell satellite that we're ready
        await self._client.write_event(RunSatellite().event())

        # Wait until we get RunPipeline event
        run_pipeline: RunPipeline | None = None
        while self.is_running and (not self.device.is_muted):
            run_event = await self._client.read_event()
            if run_event is None:
                raise ConnectionResetError("Satellite disconnected")

            if RunPipeline.is_type(run_event.type):
                run_pipeline = RunPipeline.from_event(run_event)
                break

            _LOGGER.debug("Unexpected event from satellite: %s", run_event)

        assert run_pipeline is not None
        _LOGGER.debug("Received run information: %s", run_pipeline)

        if (not self.is_running) or self.device.is_muted:
            # Run was cancelled or satellite was disabled while waiting for
            # RunPipeline event.
            return

        # Map Wyoming stages onto Assist pipeline stages.
        start_stage = _STAGES.get(run_pipeline.start_stage)
        end_stage = _STAGES.get(run_pipeline.end_stage)

        # NOTE(review): at this point start_stage/end_stage are None, so the
        # messages print "None" rather than the unmapped Wyoming stage value.
        if start_stage is None:
            raise ValueError(f"Invalid start stage: {start_stage}")

        if end_stage is None:
            raise ValueError(f"Invalid end stage: {end_stage}")

        # Each loop is a pipeline run
        while self.is_running and (not self.device.is_muted):
            # Use select to get pipeline each time in case it's changed
            pipeline_id = pipeline_select.get_chosen_pipeline(
                self.hass,
                DOMAIN,
                self.device.satellite_id,
            )
            pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
            assert pipeline is not None

            # We will push audio in through a queue
            self._audio_queue = asyncio.Queue()
            stt_stream = self._stt_stream()

            # Start pipeline running
            _LOGGER.debug(
                "Starting pipeline %s from %s to %s",
                pipeline.name,
                start_stage,
                end_stage,
            )
            self._is_pipeline_running = True
            _pipeline_task = asyncio.create_task(
                assist_pipeline.async_pipeline_from_audio_stream(
                    self.hass,
                    context=Context(),
                    # _event_callback forwards pipeline events to the satellite
                    # and clears _is_pipeline_running on RUN_END.
                    event_callback=self._event_callback,
                    stt_metadata=stt.SpeechMetadata(
                        language=pipeline.language,
                        format=stt.AudioFormats.WAV,
                        codec=stt.AudioCodecs.PCM,
                        bit_rate=stt.AudioBitRates.BITRATE_16,
                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                        channel=stt.AudioChannels.CHANNEL_MONO,
                    ),
                    stt_stream=stt_stream,
                    start_stage=start_stage,
                    end_stage=end_stage,
                    tts_audio_output="wav",
                    pipeline_id=pipeline_id,
                    audio_settings=assist_pipeline.AudioSettings(
                        noise_suppression_level=self.device.noise_suppression_level,
                        auto_gain_dbfs=self.device.auto_gain,
                        volume_multiplier=self.device.volume_multiplier,
                    ),
                    device_id=self.device.device_id,
                )
            )

            # Run until pipeline is complete or cancelled with an empty audio chunk
            while self._is_pipeline_running:
                client_event = await self._client.read_event()
                if client_event is None:
                    raise ConnectionResetError("Satellite disconnected")

                if AudioChunk.is_type(client_event.type):
                    # Microphone audio
                    chunk = AudioChunk.from_event(client_event)
                    # Normalize to 16 kHz / 16-bit / mono before queueing.
                    chunk = self._chunk_converter.convert(chunk)
                    self._audio_queue.put_nowait(chunk.audio)
                elif AudioStop.is_type(client_event.type):
                    # Stop pipeline
                    _LOGGER.debug("Client requested pipeline to stop")
                    # Empty bytes terminate _stt_stream without cancelling.
                    self._audio_queue.put_nowait(b"")
                    break
                else:
                    _LOGGER.debug("Unexpected event from satellite: %s", client_event)

            # Ensure task finishes
            await _pipeline_task

        _LOGGER.debug("Pipeline finished")
|
|
|
|
    def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
        """Translate pipeline events into Wyoming events.

        Called synchronously by the Assist pipeline; satellite writes are
        scheduled on the event loop via hass.add_job.
        """
        assert self._client is not None

        if event.type == assist_pipeline.PipelineEventType.RUN_END:
            # Pipeline run is complete
            self._is_pipeline_running = False
            self.device.set_is_active(False)
        elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
            # Tell the satellite to start wake word detection.
            self.hass.add_job(self._client.write_event(Detect().event()))
        elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
            # Wake word detection
            self.device.set_is_active(True)

            # Inform client of wake word detection
            if event.data and (wake_word_output := event.data.get("wake_word_output")):
                detection = Detection(
                    name=wake_word_output["wake_word_id"],
                    timestamp=wake_word_output.get("timestamp"),
                )
                self.hass.add_job(self._client.write_event(detection.event()))
        elif event.type == assist_pipeline.PipelineEventType.STT_START:
            # Speech-to-text
            self.device.set_is_active(True)

            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        Transcribe(language=event.data["metadata"]["language"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
            # User started speaking
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        VoiceStarted(timestamp=event.data["timestamp"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
            # User stopped speaking
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        VoiceStopped(timestamp=event.data["timestamp"]).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.STT_END:
            # Speech-to-text transcript
            if event.data:
                # Inform client of transcript
                stt_text = event.data["stt_output"]["text"]
                self.hass.add_job(
                    self._client.write_event(Transcript(text=stt_text).event())
                )
        elif event.type == assist_pipeline.PipelineEventType.TTS_START:
            # Text-to-speech text
            if event.data:
                # Inform client of text
                self.hass.add_job(
                    self._client.write_event(
                        Synthesize(
                            text=event.data["tts_input"],
                            voice=SynthesizeVoice(
                                name=event.data.get("voice"),
                                language=event.data.get("language"),
                            ),
                        ).event()
                    )
                )
        elif event.type == assist_pipeline.PipelineEventType.TTS_END:
            # TTS stream
            if event.data and (tts_output := event.data["tts_output"]):
                media_id = tts_output["media_id"]
                # Stream the synthesized WAV audio back to the satellite.
                self.hass.add_job(self._stream_tts(media_id))
        elif event.type == assist_pipeline.PipelineEventType.ERROR:
            # Pipeline error
            if event.data:
                self.hass.add_job(
                    self._client.write_event(
                        Error(
                            text=event.data["message"], code=event.data["code"]
                        ).event()
                    )
                )
|
|
|
|
async def _connect(self) -> None:
|
|
"""Connect to satellite over TCP."""
|
|
await self._disconnect()
|
|
|
|
_LOGGER.debug(
|
|
"Connecting to satellite at %s:%s", self.service.host, self.service.port
|
|
)
|
|
self._client = AsyncTcpClient(self.service.host, self.service.port)
|
|
await self._client.connect()
|
|
|
|
async def _disconnect(self) -> None:
|
|
"""Disconnect if satellite is currently connected."""
|
|
if self._client is None:
|
|
return
|
|
|
|
_LOGGER.debug("Disconnecting from satellite")
|
|
await self._client.disconnect()
|
|
self._client = None
|
|
|
|
    async def _stream_tts(self, media_id: str) -> None:
        """Stream TTS WAV audio to satellite in chunks.

        Sends AudioStart, then AudioChunk events of _SAMPLES_PER_CHUNK
        frames each, then AudioStop.

        Raises:
            ValueError: if the TTS media is not WAV audio.
        """
        assert self._client is not None

        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
        if extension != "wav":
            raise ValueError(f"Cannot stream audio format to satellite: {extension}")

        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            sample_rate = wav_file.getframerate()
            sample_width = wav_file.getsampwidth()
            sample_channels = wav_file.getnchannels()
            _LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())

            # Running timestamp advanced by each chunk's duration.
            timestamp = 0
            await self._client.write_event(
                AudioStart(
                    rate=sample_rate,
                    width=sample_width,
                    channels=sample_channels,
                    timestamp=timestamp,
                ).event()
            )

            # Stream audio chunks
            while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
                chunk = AudioChunk(
                    rate=sample_rate,
                    width=sample_width,
                    channels=sample_channels,
                    audio=audio_bytes,
                    timestamp=timestamp,
                )
                await self._client.write_event(chunk.event())
                timestamp += chunk.seconds

            await self._client.write_event(AudioStop(timestamp=timestamp).event())
            _LOGGER.debug("TTS streaming complete")
|
|
|
|
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
|
|
"""Yield audio chunks from a queue."""
|
|
is_first_chunk = True
|
|
while chunk := await self._audio_queue.get():
|
|
if is_first_chunk:
|
|
is_first_chunk = False
|
|
_LOGGER.debug("Receiving audio from satellite")
|
|
|
|
yield chunk
|