core/homeassistant/components/esphome/voice_assistant.py

167 lines
6.0 KiB
Python

"""ESPHome voice assistant support."""
from __future__ import annotations
import asyncio
from collections.abc import AsyncIterable, Callable
import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantEventType
from homeassistant.components import stt
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
async_pipeline_from_audio_stream,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
}
)
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant."""
started = False
queue: asyncio.Queue[bytes] | None = None
transport: asyncio.DatagramTransport | None = None
def __init__(self, hass: HomeAssistant) -> None:
"""Initialize UDP receiver."""
self.context = Context()
self.hass = hass
self.queue = asyncio.Queue()
async def start_server(self) -> int:
"""Start accepting connections."""
def accept_connection() -> VoiceAssistantUDPServer:
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.queue is None:
raise RuntimeError("No longer accepting connections")
self.started = True
return self
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", UDP_PORT))
await asyncio.get_running_loop().create_datagram_endpoint(
accept_connection, sock=sock
)
return cast(int, sock.getsockname()[1])
@callback
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if self.queue is not None:
self.queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
@callback
def stop(self) -> None:
"""Stop the receiver."""
if self.queue is not None:
self.queue.put_nowait(b"")
self.queue = None
if self.transport is not None:
self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
if self.queue is None:
raise RuntimeError("Already stopped")
while data := await self.queue.get():
yield data
async def run_pipeline(
self,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
) -> None:
"""Run the Voice Assistant pipeline."""
@callback
def handle_pipeline_event(event: PipelineEvent) -> None:
"""Handle pipeline events."""
try:
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
except KeyError:
_LOGGER.warning("Received unknown pipeline event type: %s", event.type)
return
data_to_send = None
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
assert event.data is not None
data_to_send = {"text": event.data["stt_output"]["text"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
assert event.data is not None
data_to_send = {"text": event.data["tts_input"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert event.data is not None
path = event.data["tts_output"]["url"]
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert event.data is not None
data_to_send = {
"code": event.data["code"],
"message": event.data["message"],
}
handle_event(event_type, data_to_send)
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
event_callback=handle_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=self._iterate_packets(),
)