528 lines
18 KiB
Python
528 lines
18 KiB
Python
"""Voice over IP (VoIP) implementation."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import deque
|
|
from collections.abc import AsyncIterable, MutableSequence, Sequence
|
|
from functools import partial
|
|
import logging
|
|
from pathlib import Path
|
|
import time
|
|
from typing import TYPE_CHECKING
|
|
|
|
from voip_utils import (
|
|
CallInfo,
|
|
RtcpState,
|
|
RtpDatagramProtocol,
|
|
SdpInfo,
|
|
VoipDatagramProtocol,
|
|
)
|
|
|
|
from homeassistant.components import stt, tts
|
|
from homeassistant.components.assist_pipeline import (
|
|
Pipeline,
|
|
PipelineEvent,
|
|
PipelineEventType,
|
|
PipelineNotFound,
|
|
async_get_pipeline,
|
|
async_pipeline_from_audio_stream,
|
|
select as pipeline_select,
|
|
)
|
|
from homeassistant.components.assist_pipeline.vad import (
|
|
AudioBuffer,
|
|
VadSensitivity,
|
|
VoiceActivityDetector,
|
|
VoiceCommandSegmenter,
|
|
WebRtcVad,
|
|
)
|
|
from homeassistant.const import __version__
|
|
from homeassistant.core import Context, HomeAssistant
|
|
from homeassistant.util.ulid import ulid
|
|
|
|
from .const import CHANNELS, DOMAIN, RATE, RTP_AUDIO_SETTINGS, WIDTH
|
|
|
|
if TYPE_CHECKING:
|
|
from .devices import VoIPDevice, VoIPDevices
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
def make_protocol(
|
|
hass: HomeAssistant,
|
|
devices: VoIPDevices,
|
|
call_info: CallInfo,
|
|
rtcp_state: RtcpState | None = None,
|
|
) -> VoipDatagramProtocol:
|
|
"""Plays a pre-recorded message if pipeline is misconfigured."""
|
|
voip_device = devices.async_get_or_create(call_info)
|
|
pipeline_id = pipeline_select.get_chosen_pipeline(
|
|
hass,
|
|
DOMAIN,
|
|
voip_device.voip_id,
|
|
)
|
|
try:
|
|
pipeline: Pipeline | None = async_get_pipeline(hass, pipeline_id)
|
|
except PipelineNotFound:
|
|
pipeline = None
|
|
|
|
if (
|
|
(pipeline is None)
|
|
or (pipeline.stt_engine is None)
|
|
or (pipeline.tts_engine is None)
|
|
):
|
|
# Play pre-recorded message instead of failing
|
|
return PreRecordMessageProtocol(
|
|
hass,
|
|
"problem.pcm",
|
|
opus_payload_type=call_info.opus_payload_type,
|
|
rtcp_state=rtcp_state,
|
|
)
|
|
|
|
vad_sensitivity = pipeline_select.get_vad_sensitivity(
|
|
hass,
|
|
DOMAIN,
|
|
voip_device.voip_id,
|
|
)
|
|
|
|
# Pipeline is properly configured
|
|
return PipelineRtpDatagramProtocol(
|
|
hass,
|
|
hass.config.language,
|
|
voip_device,
|
|
Context(user_id=devices.config_entry.data["user"]),
|
|
opus_payload_type=call_info.opus_payload_type,
|
|
silence_seconds=VadSensitivity.to_seconds(vad_sensitivity),
|
|
rtcp_state=rtcp_state,
|
|
)
|
|
|
|
|
|
class HassVoipDatagramProtocol(VoipDatagramProtocol):
|
|
"""HA UDP server for Voice over IP (VoIP)."""
|
|
|
|
def __init__(self, hass: HomeAssistant, devices: VoIPDevices) -> None:
|
|
"""Set up VoIP call handler."""
|
|
super().__init__(
|
|
sdp_info=SdpInfo(
|
|
username="homeassistant",
|
|
id=time.monotonic_ns(),
|
|
session_name="voip_hass",
|
|
version=__version__,
|
|
),
|
|
valid_protocol_factory=lambda call_info, rtcp_state: make_protocol(
|
|
hass, devices, call_info, rtcp_state
|
|
),
|
|
invalid_protocol_factory=lambda call_info, rtcp_state: PreRecordMessageProtocol(
|
|
hass,
|
|
"not_configured.pcm",
|
|
opus_payload_type=call_info.opus_payload_type,
|
|
rtcp_state=rtcp_state,
|
|
),
|
|
)
|
|
self.hass = hass
|
|
self.devices = devices
|
|
self._closed_event = asyncio.Event()
|
|
|
|
def is_valid_call(self, call_info: CallInfo) -> bool:
|
|
"""Filter calls."""
|
|
device = self.devices.async_get_or_create(call_info)
|
|
return device.async_allow_call(self.hass)
|
|
|
|
def connection_lost(self, exc):
|
|
"""Signal wait_closed when transport is completely closed."""
|
|
self.hass.loop.call_soon_threadsafe(self._closed_event.set)
|
|
|
|
async def wait_closed(self) -> None:
|
|
"""Wait for connection_lost to be called."""
|
|
await self._closed_event.wait()
|
|
|
|
|
|
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
|
|
"""Run a voice assistant pipeline in a loop for a VoIP call."""
|
|
|
|
def __init__(
|
|
self,
|
|
hass: HomeAssistant,
|
|
language: str,
|
|
voip_device: VoIPDevice,
|
|
context: Context,
|
|
opus_payload_type: int,
|
|
pipeline_timeout: float = 30.0,
|
|
audio_timeout: float = 2.0,
|
|
buffered_chunks_before_speech: int = 100,
|
|
listening_tone_enabled: bool = True,
|
|
processing_tone_enabled: bool = True,
|
|
error_tone_enabled: bool = True,
|
|
tone_delay: float = 0.2,
|
|
tts_extra_timeout: float = 1.0,
|
|
silence_seconds: float = 1.0,
|
|
rtcp_state: RtcpState | None = None,
|
|
) -> None:
|
|
"""Set up pipeline RTP server."""
|
|
super().__init__(
|
|
rate=RATE,
|
|
width=WIDTH,
|
|
channels=CHANNELS,
|
|
opus_payload_type=opus_payload_type,
|
|
rtcp_state=rtcp_state,
|
|
)
|
|
|
|
self.hass = hass
|
|
self.language = language
|
|
self.voip_device = voip_device
|
|
self.pipeline: Pipeline | None = None
|
|
self.pipeline_timeout = pipeline_timeout
|
|
self.audio_timeout = audio_timeout
|
|
self.buffered_chunks_before_speech = buffered_chunks_before_speech
|
|
self.listening_tone_enabled = listening_tone_enabled
|
|
self.processing_tone_enabled = processing_tone_enabled
|
|
self.error_tone_enabled = error_tone_enabled
|
|
self.tone_delay = tone_delay
|
|
self.tts_extra_timeout = tts_extra_timeout
|
|
self.silence_seconds = silence_seconds
|
|
|
|
self._audio_queue: asyncio.Queue[bytes] = asyncio.Queue()
|
|
self._context = context
|
|
self._conversation_id: str | None = None
|
|
self._pipeline_task: asyncio.Task | None = None
|
|
self._tts_done = asyncio.Event()
|
|
self._session_id: str | None = None
|
|
self._tone_bytes: bytes | None = None
|
|
self._processing_bytes: bytes | None = None
|
|
self._error_bytes: bytes | None = None
|
|
self._pipeline_error: bool = False
|
|
|
|
def connection_made(self, transport):
|
|
"""Server is ready."""
|
|
super().connection_made(transport)
|
|
self.voip_device.set_is_active(True)
|
|
|
|
def connection_lost(self, exc):
|
|
"""Handle connection is lost or closed."""
|
|
super().connection_lost(exc)
|
|
self.voip_device.set_is_active(False)
|
|
|
|
def on_chunk(self, audio_bytes: bytes) -> None:
|
|
"""Handle raw audio chunk."""
|
|
if self._pipeline_task is None:
|
|
self._clear_audio_queue()
|
|
|
|
# Run pipeline until voice command finishes, then start over
|
|
self._pipeline_task = self.hass.async_create_background_task(
|
|
self._run_pipeline(),
|
|
"voip_pipeline_run",
|
|
)
|
|
|
|
self._audio_queue.put_nowait(audio_bytes)
|
|
|
|
async def _run_pipeline(
|
|
self,
|
|
) -> None:
|
|
"""Forward audio to pipeline STT and handle TTS."""
|
|
if self._session_id is None:
|
|
self._session_id = ulid()
|
|
|
|
# Play listening tone at the start of each cycle
|
|
if self.listening_tone_enabled:
|
|
await self._play_listening_tone()
|
|
|
|
try:
|
|
# Wait for speech before starting pipeline
|
|
segmenter = VoiceCommandSegmenter(silence_seconds=self.silence_seconds)
|
|
vad = WebRtcVad()
|
|
chunk_buffer: deque[bytes] = deque(
|
|
maxlen=self.buffered_chunks_before_speech,
|
|
)
|
|
speech_detected = await self._wait_for_speech(
|
|
segmenter,
|
|
vad,
|
|
chunk_buffer,
|
|
)
|
|
if not speech_detected:
|
|
_LOGGER.debug("No speech detected")
|
|
return
|
|
|
|
_LOGGER.debug("Starting pipeline")
|
|
self._tts_done.clear()
|
|
|
|
async def stt_stream():
|
|
try:
|
|
async for chunk in self._segment_audio(
|
|
segmenter,
|
|
vad,
|
|
chunk_buffer,
|
|
):
|
|
yield chunk
|
|
|
|
if self.processing_tone_enabled:
|
|
await self._play_processing_tone()
|
|
except asyncio.TimeoutError:
|
|
# Expected after caller hangs up
|
|
_LOGGER.debug("Audio timeout")
|
|
self._session_id = None
|
|
self.disconnect()
|
|
finally:
|
|
self._clear_audio_queue()
|
|
|
|
# Run pipeline with a timeout
|
|
async with asyncio.timeout(self.pipeline_timeout):
|
|
await async_pipeline_from_audio_stream(
|
|
self.hass,
|
|
context=self._context,
|
|
event_callback=self._event_callback,
|
|
stt_metadata=stt.SpeechMetadata(
|
|
language="", # set in async_pipeline_from_audio_stream
|
|
format=stt.AudioFormats.WAV,
|
|
codec=stt.AudioCodecs.PCM,
|
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
|
),
|
|
stt_stream=stt_stream(),
|
|
pipeline_id=pipeline_select.get_chosen_pipeline(
|
|
self.hass, DOMAIN, self.voip_device.voip_id
|
|
),
|
|
conversation_id=self._conversation_id,
|
|
device_id=self.voip_device.device_id,
|
|
tts_audio_output="raw",
|
|
)
|
|
|
|
if self._pipeline_error:
|
|
self._pipeline_error = False
|
|
if self.error_tone_enabled:
|
|
await self._play_error_tone()
|
|
else:
|
|
# Block until TTS is done speaking.
|
|
#
|
|
# This is set in _send_tts and has a timeout that's based on the
|
|
# length of the TTS audio.
|
|
await self._tts_done.wait()
|
|
|
|
_LOGGER.debug("Pipeline finished")
|
|
except PipelineNotFound:
|
|
_LOGGER.warning("Pipeline not found")
|
|
except asyncio.TimeoutError:
|
|
# Expected after caller hangs up
|
|
_LOGGER.debug("Pipeline timeout")
|
|
self._session_id = None
|
|
self.disconnect()
|
|
finally:
|
|
# Allow pipeline to run again
|
|
self._pipeline_task = None
|
|
|
|
async def _wait_for_speech(
|
|
self,
|
|
segmenter: VoiceCommandSegmenter,
|
|
vad: VoiceActivityDetector,
|
|
chunk_buffer: MutableSequence[bytes],
|
|
):
|
|
"""Buffer audio chunks until speech is detected.
|
|
|
|
Returns True if speech was detected, False otherwise.
|
|
"""
|
|
# Timeout if no audio comes in for a while.
|
|
# This means the caller hung up.
|
|
async with asyncio.timeout(self.audio_timeout):
|
|
chunk = await self._audio_queue.get()
|
|
|
|
assert vad.samples_per_chunk is not None
|
|
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
|
|
|
|
while chunk:
|
|
chunk_buffer.append(chunk)
|
|
|
|
segmenter.process_with_vad(chunk, vad, vad_buffer)
|
|
if segmenter.in_command:
|
|
# Buffer until command starts
|
|
if len(vad_buffer) > 0:
|
|
chunk_buffer.append(vad_buffer.bytes())
|
|
|
|
return True
|
|
|
|
async with asyncio.timeout(self.audio_timeout):
|
|
chunk = await self._audio_queue.get()
|
|
|
|
return False
|
|
|
|
async def _segment_audio(
|
|
self,
|
|
segmenter: VoiceCommandSegmenter,
|
|
vad: VoiceActivityDetector,
|
|
chunk_buffer: Sequence[bytes],
|
|
) -> AsyncIterable[bytes]:
|
|
"""Yield audio chunks until voice command has finished."""
|
|
# Buffered chunks first
|
|
for buffered_chunk in chunk_buffer:
|
|
yield buffered_chunk
|
|
|
|
# Timeout if no audio comes in for a while.
|
|
# This means the caller hung up.
|
|
async with asyncio.timeout(self.audio_timeout):
|
|
chunk = await self._audio_queue.get()
|
|
|
|
assert vad.samples_per_chunk is not None
|
|
vad_buffer = AudioBuffer(vad.samples_per_chunk * WIDTH)
|
|
|
|
while chunk:
|
|
if not segmenter.process_with_vad(chunk, vad, vad_buffer):
|
|
# Voice command is finished
|
|
break
|
|
|
|
yield chunk
|
|
|
|
async with asyncio.timeout(self.audio_timeout):
|
|
chunk = await self._audio_queue.get()
|
|
|
|
def _clear_audio_queue(self) -> None:
|
|
while not self._audio_queue.empty():
|
|
self._audio_queue.get_nowait()
|
|
|
|
def _event_callback(self, event: PipelineEvent):
|
|
if not event.data:
|
|
return
|
|
|
|
if event.type == PipelineEventType.INTENT_END:
|
|
# Capture conversation id
|
|
self._conversation_id = event.data["intent_output"]["conversation_id"]
|
|
elif event.type == PipelineEventType.TTS_END:
|
|
# Send TTS audio to caller over RTP
|
|
media_id = event.data["tts_output"]["media_id"]
|
|
self.hass.async_create_background_task(
|
|
self._send_tts(media_id),
|
|
"voip_pipeline_tts",
|
|
)
|
|
elif event.type == PipelineEventType.ERROR:
|
|
# Play error tone instead of wait for TTS
|
|
self._pipeline_error = True
|
|
|
|
async def _send_tts(self, media_id: str) -> None:
|
|
"""Send TTS audio to caller via RTP."""
|
|
try:
|
|
if self.transport is None:
|
|
return
|
|
|
|
_extension, audio_bytes = await tts.async_get_media_source_audio(
|
|
self.hass,
|
|
media_id,
|
|
)
|
|
|
|
_LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))
|
|
|
|
# Time out 1 second after TTS audio should be finished
|
|
tts_samples = len(audio_bytes) / (WIDTH * CHANNELS)
|
|
tts_seconds = tts_samples / RATE
|
|
|
|
async with asyncio.timeout(tts_seconds + self.tts_extra_timeout):
|
|
# Assume TTS audio is 16Khz 16-bit mono
|
|
await self._async_send_audio(audio_bytes)
|
|
except asyncio.TimeoutError as err:
|
|
_LOGGER.warning("TTS timeout")
|
|
raise err
|
|
finally:
|
|
# Signal pipeline to restart
|
|
self._tts_done.set()
|
|
|
|
async def _async_send_audio(self, audio_bytes: bytes, **kwargs):
|
|
"""Send audio in executor."""
|
|
await self.hass.async_add_executor_job(
|
|
partial(self.send_audio, audio_bytes, **RTP_AUDIO_SETTINGS, **kwargs)
|
|
)
|
|
|
|
async def _play_listening_tone(self) -> None:
|
|
"""Play a tone to indicate that Home Assistant is listening."""
|
|
if self._tone_bytes is None:
|
|
# Do I/O in executor
|
|
self._tone_bytes = await self.hass.async_add_executor_job(
|
|
self._load_pcm,
|
|
"tone.pcm",
|
|
)
|
|
|
|
await self._async_send_audio(
|
|
self._tone_bytes,
|
|
silence_before=self.tone_delay,
|
|
)
|
|
|
|
async def _play_processing_tone(self) -> None:
|
|
"""Play a tone to indicate that Home Assistant is processing the voice command."""
|
|
if self._processing_bytes is None:
|
|
# Do I/O in executor
|
|
self._processing_bytes = await self.hass.async_add_executor_job(
|
|
self._load_pcm,
|
|
"processing.pcm",
|
|
)
|
|
|
|
await self._async_send_audio(self._processing_bytes)
|
|
|
|
async def _play_error_tone(self) -> None:
|
|
"""Play a tone to indicate a pipeline error occurred."""
|
|
if self._error_bytes is None:
|
|
# Do I/O in executor
|
|
self._error_bytes = await self.hass.async_add_executor_job(
|
|
self._load_pcm,
|
|
"error.pcm",
|
|
)
|
|
|
|
await self._async_send_audio(self._error_bytes)
|
|
|
|
def _load_pcm(self, file_name: str) -> bytes:
|
|
"""Load raw audio (16Khz, 16-bit mono)."""
|
|
return (Path(__file__).parent / file_name).read_bytes()
|
|
|
|
|
|
class PreRecordMessageProtocol(RtpDatagramProtocol):
|
|
"""Plays a pre-recorded message on a loop."""
|
|
|
|
def __init__(
|
|
self,
|
|
hass: HomeAssistant,
|
|
file_name: str,
|
|
opus_payload_type: int,
|
|
message_delay: float = 1.0,
|
|
loop_delay: float = 2.0,
|
|
rtcp_state: RtcpState | None = None,
|
|
) -> None:
|
|
"""Set up RTP server."""
|
|
super().__init__(
|
|
rate=RATE,
|
|
width=WIDTH,
|
|
channels=CHANNELS,
|
|
opus_payload_type=opus_payload_type,
|
|
rtcp_state=rtcp_state,
|
|
)
|
|
self.hass = hass
|
|
self.file_name = file_name
|
|
self.message_delay = message_delay
|
|
self.loop_delay = loop_delay
|
|
self._audio_task: asyncio.Task | None = None
|
|
self._audio_bytes: bytes | None = None
|
|
|
|
def on_chunk(self, audio_bytes: bytes) -> None:
|
|
"""Handle raw audio chunk."""
|
|
if self.transport is None:
|
|
return
|
|
|
|
if self._audio_bytes is None:
|
|
# 16Khz, 16-bit mono audio message
|
|
file_path = Path(__file__).parent / self.file_name
|
|
self._audio_bytes = file_path.read_bytes()
|
|
|
|
if self._audio_task is None:
|
|
self._audio_task = self.hass.async_create_background_task(
|
|
self._play_message(),
|
|
"voip_not_connected",
|
|
)
|
|
|
|
async def _play_message(self) -> None:
|
|
await self.hass.async_add_executor_job(
|
|
partial(
|
|
self.send_audio,
|
|
self._audio_bytes,
|
|
silence_before=self.message_delay,
|
|
**RTP_AUDIO_SETTINGS,
|
|
)
|
|
)
|
|
|
|
await asyncio.sleep(self.loop_delay)
|
|
|
|
# Allow message to play again
|
|
self._audio_task = None
|