"""Voice over IP (VoIP) implementation."""
from __future__ import annotations

import asyncio
from collections import deque
from collections.abc import AsyncIterable
import logging
import time
from typing import TYPE_CHECKING

import async_timeout
from voip_utils import CallInfo, RtpDatagramProtocol, SdpInfo, VoipDatagramProtocol

from homeassistant.components import stt, tts
from homeassistant.components.assist_pipeline import (
    Pipeline,
    PipelineEvent,
    PipelineEventType,
    async_pipeline_from_audio_stream,
    select as pipeline_select,
)
from homeassistant.components.assist_pipeline.vad import VoiceCommandSegmenter
from homeassistant.const import __version__
from homeassistant.core import Context, HomeAssistant

from .const import DOMAIN

if TYPE_CHECKING:
    from .devices import VoIPDevice, VoIPDevices

_BUFFERED_CHUNKS_BEFORE_SPEECH = 100  # ~2 seconds

_LOGGER = logging.getLogger(__name__)


class HassVoipDatagramProtocol(VoipDatagramProtocol):
    """HA UDP server for Voice over IP (VoIP)."""

    def __init__(self, hass: HomeAssistant, devices: VoIPDevices) -> None:
        """Set up VoIP call handler."""
        super().__init__(
            sdp_info=SdpInfo(
                username="homeassistant",
                id=time.monotonic_ns(),
                session_name="voip_hass",
                version=__version__,
            ),
            protocol_factory=lambda call_info: PipelineRtpDatagramProtocol(
                hass,
                hass.config.language,
                devices.async_get_or_create(call_info),
            ),
        )
        self.hass = hass
        self.devices = devices

    def is_valid_call(self, call_info: CallInfo) -> bool:
        """Filter calls."""
        device = self.devices.async_get_or_create(call_info)
        return device.async_allow_call(self.hass)


class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
    """Run a voice assistant pipeline in a loop for a VoIP call."""

    def __init__(
        self,
        hass: HomeAssistant,
        language: str,
        voip_device: VoIPDevice,
        pipeline_timeout: float = 30.0,
        audio_timeout: float = 2.0,
    ) -> None:
        """Set up pipeline RTP server."""
        # STT expects 16 kHz mono with 16-bit samples
        super().__init__(rate=16000, width=2, channels=1)

        self.hass = hass
        self.language = language
        self.voip_device = voip_device
        self.pipeline: Pipeline | None = None
        self.pipeline_timeout = pipeline_timeout
        self.audio_timeout = audio_timeout

        self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
        self._context = Context()
        self._conversation_id: str | None = None
        self._pipeline_task: asyncio.Task | None = None

    def connection_made(self, transport):
        """Server is ready."""
        super().connection_made(transport)
        self.voip_device.set_is_active(True)

    def connection_lost(self, exc):
        """Handle a lost or closed connection."""
        super().connection_lost(exc)
        self.voip_device.set_is_active(False)

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Handle raw audio chunk."""
        if self._pipeline_task is None:
            self._clear_audio_queue()

            # Run pipeline until voice command finishes, then start over
            self._pipeline_task = self.hass.async_create_background_task(
                self._run_pipeline(),
                "voip_pipeline_run",
            )

        self._audio_queue.put_nowait(audio_bytes)

    async def _run_pipeline(self) -> None:
        """Forward audio to pipeline STT and handle TTS."""
        _LOGGER.debug("Starting pipeline")

        async def stt_stream():
            try:
                async for chunk in self._segment_audio():
                    yield chunk
            except asyncio.TimeoutError:
                # Expected after caller hangs up
                _LOGGER.debug("Audio timeout")

                if self.transport is not None:
                    self.transport.close()
                    self.transport = None
            finally:
                self._clear_audio_queue()

        try:
            # Run pipeline with a timeout
            async with async_timeout.timeout(self.pipeline_timeout):
                await async_pipeline_from_audio_stream(
                    self.hass,
                    context=self._context,
                    event_callback=self._event_callback,
                    stt_metadata=stt.SpeechMetadata(
                        language="",  # set in async_pipeline_from_audio_stream
                        format=stt.AudioFormats.WAV,
                        codec=stt.AudioCodecs.PCM,
                        bit_rate=stt.AudioBitRates.BITRATE_16,
                        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
                        channel=stt.AudioChannels.CHANNEL_MONO,
                    ),
                    stt_stream=stt_stream(),
                    pipeline_id=pipeline_select.get_chosen_pipeline(
                        self.hass, DOMAIN, self.voip_device.voip_id
                    ),
                    conversation_id=self._conversation_id,
                    tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
                )
        except asyncio.TimeoutError:
            # Expected after caller hangs up
            _LOGGER.debug("Pipeline timeout")

            if self.transport is not None:
                self.transport.close()
                self.transport = None
        finally:
            # Allow pipeline to run again
            self._pipeline_task = None

    async def _segment_audio(self) -> AsyncIterable[bytes]:
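        """Yield audio chunks from the queue until the voice command ends.

        Chunks received before speech is detected are buffered so the start of
        the command is not lost.
        """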
        segmenter = VoiceCommandSegmenter()
        chunk_buffer: deque[bytes] = deque(maxlen=_BUFFERED_CHUNKS_BEFORE_SPEECH)

        # Timeout if no audio comes in for a while.
        # This means the caller hung up.
        async with async_timeout.timeout(self.audio_timeout):
            chunk = await self._audio_queue.get()

        while chunk:
            if not segmenter.process(chunk):
                # Voice command is finished
                break

            if segmenter.in_command:
                if chunk_buffer:
                    # Release audio in buffer first
                    for buffered_chunk in chunk_buffer:
                        yield buffered_chunk

                    chunk_buffer.clear()

                yield chunk
            else:
                # Buffer until command starts
                chunk_buffer.append(chunk)

            async with async_timeout.timeout(self.audio_timeout):
                chunk = await self._audio_queue.get()

    def _clear_audio_queue(self) -> None:
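        """Ensure the audio queue is empty."""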
        while not self._audio_queue.empty():
            self._audio_queue.get_nowait()

    def _event_callback(self, event: PipelineEvent):
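        """Handle pipeline events: capture the conversation id and play TTS."""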
        if not event.data:
            return

        if event.type == PipelineEventType.INTENT_END:
            # Capture conversation id
            self._conversation_id = event.data["intent_output"]["conversation_id"]
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            media_id = event.data["tts_output"]["media_id"]
            self.hass.async_create_background_task(
                self._send_media(media_id),
                "voip_pipeline_tts",
            )

    async def _send_media(self, media_id: str) -> None:
        """Send TTS audio to caller via RTP."""
        if self.transport is None:
            return

        _extension, audio_bytes = await tts.async_get_media_source_audio(
            self.hass,
            media_id,
        )

        _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

        # Assume TTS audio is 16 kHz 16-bit mono
        await self.send_audio(audio_bytes, rate=16000, width=2, channels=1)