# core/homeassistant/components/esphome/voice_assistant.py

"""ESPHome voice assistant support."""
from __future__ import annotations
import asyncio
from collections import deque
from collections.abc import AsyncIterable, Callable, MutableSequence, Sequence
import logging
import socket
from typing import cast
from aioesphomeapi import VoiceAssistantEventType
from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
PipelineEvent,
PipelineEventType,
PipelineNotFound,
async_pipeline_from_audio_stream,
select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import (
VadSensitivity,
VoiceCommandSegmenter,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback
from .const import DOMAIN
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__)
UDP_PORT = 0 # Set to 0 to let the OS pick a free random port
UDP_MAX_PACKET_SIZE = 1024
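# Maps Home Assistant assist pipeline event types to the ESPHome voice assistant
# event enum so pipeline progress can be forwarded to the device over the native API.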
_VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
VoiceAssistantEventType, PipelineEventType
] = EsphomeEnumMapper(
{
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START,
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START,
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START,
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END,
}
)
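# Typical lifecycle, sketched for orientation only (the real caller lives in the
# integration's connection handling and may differ in detail):
#   udp_server = VoiceAssistantUDPServer(hass, entry_data, handle_event, handle_finished)
#   port = await udp_server.start_server()   # port number is reported to the device
#   await udp_server.run_pipeline(device_id, conversation_id, use_vad=True)
#   udp_server.stop()                         # device finished streaming audio
#   udp_server.close()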
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant."""
started = False
stopped = False
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__(
self,
hass: HomeAssistant,
entry_data: RuntimeEntryData,
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None],
audio_timeout: float = 2.0,
) -> None:
"""Initialize UDP receiver."""
self.context = Context()
self.hass = hass
assert entry_data.device_info is not None
self.device_info = entry_data.device_info
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
self.handle_event = handle_event
self.handle_finished = handle_finished
self._tts_done = asyncio.Event()
self.audio_timeout = audio_timeout
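# Binding to UDP_PORT == 0 lets the OS pick a free port; the chosen port number is
# returned so it can be sent to the ESPHome device, which then streams microphone
# audio to this socket.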
async def start_server(self) -> int:
"""Start accepting connections."""
def accept_connection() -> VoiceAssistantUDPServer:
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.stopped:
raise RuntimeError("No longer accepting connections")
self.started = True
return self
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", UDP_PORT))
await asyncio.get_running_loop().create_datagram_endpoint(
accept_connection, sock=sock
)
return cast(int, sock.getsockname()[1])
@callback
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
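# The source address of the first packet is remembered so TTS audio can later be
# sent back to the same device over this socket.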
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if not self.started or self.stopped:
return
if self.remote_addr is None:
self.remote_addr = addr
self.queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
self.handle_finished()
@callback
def stop(self) -> None:
"""Stop the receiver."""
self.queue.put_nowait(b"")
self.started = False
self.stopped = True
def close(self) -> None:
"""Close the receiver."""
self.started = False
self.stopped = True
if self.transport is not None:
self.transport.close()
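# Plain iterator used when server-side VAD is not requested: yields raw UDP payloads
# until the empty-bytes sentinel queued by stop() ends the stream.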
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
if not self.started or self.stopped:
raise RuntimeError("Not running")
while data := await self.queue.get():
yield data
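# Translates pipeline events back into ESPHome event types and forwards them to the
# device. TTS output is the special case: protocol version 2 devices get the audio
# streamed over this UDP socket, older devices get a playable URL.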
def _event_callback(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""
try:
event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type)
except KeyError:
_LOGGER.warning("Received unknown pipeline event type: %s", event.type)
return
data_to_send = None
error = False
if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END:
assert event.data is not None
data_to_send = {"text": event.data["stt_output"]["text"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END:
assert event.data is not None
data_to_send = {
"conversation_id": event.data["intent_output"]["conversation_id"] or "",
}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START:
assert event.data is not None
data_to_send = {"text": event.data["tts_input"]}
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END:
assert event.data is not None
path = event.data["tts_output"]["url"]
url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url}
if self.device_info.voice_assistant_version >= 2:
media_id = event.data["tts_output"]["media_id"]
self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts"
)
else:
self._tts_done.set()
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR:
assert event.data is not None
data_to_send = {
"code": event.data["code"],
"message": event.data["message"],
}
self._tts_done.set()
error = True
self.handle_event(event_type, data_to_send)
if error:
self.handle_finished()
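# Audio received before speech is detected is kept in chunk_buffer (a bounded deque)
# so the start of the command can be replayed into the STT stream once the pipeline runs.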
async def _wait_for_speech(
self,
segmenter: VoiceCommandSegmenter,
chunk_buffer: MutableSequence[bytes],
) -> bool:
"""Buffer audio chunks until speech is detected.
Raises asyncio.TimeoutError if no audio data is retrievable from the queue (device stops sending packets / networking issue).
Returns True if speech was detected.
Returns False if the connection was stopped gracefully (b"" put onto the queue).
"""
# Timeout if no audio comes in for a while.
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
while chunk:
segmenter.process(chunk)
# Buffer the data we have taken from the queue
chunk_buffer.append(chunk)
if segmenter.in_command:
return True
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
# If chunk is falsey, `stop()` was called
return False
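# Replays the buffered pre-speech chunks first, then streams live audio until the
# segmenter reports that the voice command has finished.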
async def _segment_audio(
self,
segmenter: VoiceCommandSegmenter,
chunk_buffer: Sequence[bytes],
) -> AsyncIterable[bytes]:
"""Yield audio chunks until voice command has finished.
Raises asyncio.TimeoutError if no audio data is retrievable from the queue.
"""
# Buffered chunks first
for buffered_chunk in chunk_buffer:
yield buffered_chunk
# Timeout if no audio comes in for a while.
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
while chunk:
if not segmenter.process(chunk):
# Voice command is finished
break
yield chunk
async with asyncio.timeout(self.audio_timeout):
chunk = await self.queue.get()
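# Two-phase VAD flow: block (up to pipeline_timeout) until speech is detected, then
# return a generator factory that streams the buffered and live audio into STT.
# Returns None if the timeout expired or the device stopped sending audio early.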
async def _iterate_packets_with_vad(
self, pipeline_timeout: float, silence_seconds: float
) -> Callable[[], AsyncIterable[bytes]] | None:
segmenter = VoiceCommandSegmenter(silence_seconds=silence_seconds)
chunk_buffer: deque[bytes] = deque(maxlen=100)
try:
async with asyncio.timeout(pipeline_timeout):
speech_detected = await self._wait_for_speech(segmenter, chunk_buffer)
if not speech_detected:
_LOGGER.debug(
"Device stopped sending audio before speech was detected"
)
self.handle_finished()
return None
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "speech-timeout",
"message": "Timed out waiting for speech",
},
)
self.handle_finished()
return None
async def _stream_packets() -> AsyncIterable[bytes]:
try:
async for chunk in self._segment_audio(segmenter, chunk_buffer):
yield chunk
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "speech-timeout",
"message": "No speech detected",
},
)
self.handle_finished()
return _stream_packets
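# Protocol version 2 devices receive TTS as raw PCM streamed back over this socket
# (paced in _send_tts assuming 16-bit samples at 16 kHz); version 1 devices fetch an
# MP3 from the URL sent with the TTS_END event.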
async def run_pipeline(
self,
device_id: str,
conversation_id: str | None,
use_vad: bool = False,
pipeline_timeout: float = 30.0,
) -> None:
"""Run the Voice Assistant pipeline."""
tts_audio_output = (
"raw" if self.device_info.voice_assistant_version >= 2 else "mp3"
)
if use_vad:
stt_stream = await self._iterate_packets_with_vad(
pipeline_timeout,
silence_seconds=VadSensitivity.to_seconds(
pipeline_select.get_vad_sensitivity(
self.hass,
DOMAIN,
self.device_info.mac_address,
)
),
)
# Error or timeout occurred and was handled already
if stt_stream is None:
return
else:
stt_stream = self._iterate_packets
_LOGGER.debug("Starting pipeline")
try:
async with asyncio.timeout(pipeline_timeout):
await async_pipeline_from_audio_stream(
self.hass,
context=self.context,
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.device_info.mac_address
),
conversation_id=conversation_id,
device_id=device_id,
tts_audio_output=tts_audio_output,
)
# Block until TTS is done sending
await self._tts_done.wait()
_LOGGER.debug("Pipeline finished")
except PipelineNotFound:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline not found",
"message": "Selected pipeline timeout",
},
)
_LOGGER.warning("Pipeline not found")
except asyncio.TimeoutError:
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{
"code": "pipeline-timeout",
"message": "Pipeline timeout",
},
)
_LOGGER.warning("Pipeline timeout")
finally:
self.handle_finished()
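# TTS audio is sent in 1024-byte datagrams, paced just under real time (the 0.99
# factor) so the device's playback buffer keeps up without being flooded.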
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to device via UDP."""
try:
if self.transport is None:
return
_extension, audio_bytes = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
_LOGGER.debug("Sending %d bytes of audio", len(audio_bytes))
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
sample_offset = 0
samples_left = len(audio_bytes) // bytes_per_sample
while samples_left > 0:
bytes_offset = sample_offset * bytes_per_sample
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
samples_in_chunk = len(chunk) // bytes_per_sample
samples_left -= samples_in_chunk
self.transport.sendto(chunk, self.remote_addr)
await asyncio.sleep(
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.99
)
sample_offset += samples_in_chunk
finally:
self._tts_done.set()