Voip integration (#90945)
* Media playback working * Working on OPUS audio * Before rollback * Fix is_end * First working pipeline * Clean up * Remove asserts * Send HA version in SDP * Use async_pipeline_from_audio_stream * Use config flow with allowed IP * Satisfy ruff * Remove use of regex for SIP IP * Use voip-utils * Fix imports * Add Pipeline to __all__ * Fix voice assistant tests * Basic VoIP test * Run hassfest * Generate requirements * Bump voip utils (missing requirement) * Allow tts_options to be passed in to pipeline run * Add config flow tests * Update test snapshots * More tests * Remove get_extra_info * Appeasing the codebot
pull/88626/head^2
parent
3a72054f93
commit
78fec33b17
|
@ -1312,6 +1312,8 @@ build.json @home-assistant/supervisor
|
|||
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||
/homeassistant/components/voice_assistant/ @balloob @synesthesiam
|
||||
/tests/components/voice_assistant/ @balloob @synesthesiam
|
||||
/homeassistant/components/voip/ @balloob @synesthesiam
|
||||
/tests/components/voip/ @balloob @synesthesiam
|
||||
/homeassistant/components/volumio/ @OnFreund
|
||||
/tests/components/volumio/ @OnFreund
|
||||
/homeassistant/components/volvooncall/ @molobrakos
|
||||
|
|
|
@ -10,6 +10,7 @@ from homeassistant.helpers.typing import ConfigType
|
|||
from .const import DOMAIN
|
||||
from .error import PipelineNotFound
|
||||
from .pipeline import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventCallback,
|
||||
PipelineEventType,
|
||||
|
@ -25,6 +26,7 @@ __all__ = (
|
|||
"DOMAIN",
|
||||
"async_setup",
|
||||
"async_pipeline_from_audio_stream",
|
||||
"Pipeline",
|
||||
"PipelineEvent",
|
||||
"PipelineEventType",
|
||||
)
|
||||
|
@ -47,6 +49,7 @@ async def async_pipeline_from_audio_stream(
|
|||
pipeline_id: str | None = None,
|
||||
conversation_id: str | None = None,
|
||||
context: Context | None = None,
|
||||
tts_options: dict | None = None,
|
||||
) -> None:
|
||||
"""Create an audio pipeline from an audio stream."""
|
||||
if language is None:
|
||||
|
@ -83,6 +86,7 @@ async def async_pipeline_from_audio_stream(
|
|||
start_stage=PipelineStage.STT,
|
||||
end_stage=PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
tts_options=tts_options,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
@ -174,6 +174,7 @@ class PipelineRun:
|
|||
stt_provider: stt.Provider | None = None
|
||||
intent_agent: str | None = None
|
||||
tts_engine: str | None = None
|
||||
tts_options: dict | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set language for pipeline."""
|
||||
|
@ -357,12 +358,17 @@ class PipelineRun:
|
|||
message=f"Text to speech engine '{engine}' not found",
|
||||
)
|
||||
|
||||
if not await tts.async_support_options(self.hass, engine, self.language):
|
||||
if not await tts.async_support_options(
|
||||
self.hass,
|
||||
engine,
|
||||
self.language,
|
||||
self.tts_options,
|
||||
):
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text to speech engine {engine} "
|
||||
f"does not support language {self.language}"
|
||||
f"does not support language {self.language} or options {self.tts_options}"
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -385,14 +391,16 @@ class PipelineRun:
|
|||
|
||||
try:
|
||||
# Synthesize audio and get URL
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_engine,
|
||||
language=self.language,
|
||||
options=self.tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_engine,
|
||||
language=self.language,
|
||||
),
|
||||
tts_media_id,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during text to speech")
|
||||
|
@ -406,7 +414,12 @@ class PipelineRun:
|
|||
self.event_callback(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_END,
|
||||
{"tts_output": asdict(tts_media)},
|
||||
{
|
||||
"tts_output": {
|
||||
"media_id": tts_media_id,
|
||||
**asdict(tts_media),
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
"""The Voice over IP integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import logging
|
||||
|
||||
from voip_utils import SIP_PORT
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_IP_ADDRESS
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
from .voip import HassVoipDatagramProtocol
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
_IP_WILDCARD = "0.0.0.0"
|
||||
|
||||
__all__ = [
|
||||
"DOMAIN",
|
||||
"async_setup_entry",
|
||||
"async_unload_entry",
|
||||
]
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Start the SIP server for a config entry.

    The IP address stored in the entry is the only caller address handed to
    the SIP protocol as allowed. The resulting datagram transport is kept in
    ``hass.data[DOMAIN]`` so it can be closed on unload.
    """
    allowed_ip = entry.data[CONF_IP_ADDRESS]
    _LOGGER.debug(
        "Listening for VoIP calls from %s (port=%s)",
        allowed_ip,
        SIP_PORT,
    )

    def _make_protocol() -> HassVoipDatagramProtocol:
        # Invoked by the event loop when the datagram endpoint is created.
        return HassVoipDatagramProtocol(hass, {str(allowed_ip)})

    sip_transport = await _create_sip_server(hass, _make_protocol)
    hass.data[DOMAIN] = sip_transport

    return True
|
||||
|
||||
|
||||
async def _create_sip_server(
    hass: HomeAssistant,
    protocol_factory: Callable[[], asyncio.DatagramProtocol],
) -> asyncio.DatagramTransport:
    """Open a UDP endpoint on the wildcard address at SIP_PORT.

    The protocol instance built by ``protocol_factory`` is discarded here;
    only the transport is returned to the caller.
    """
    bind_address = (_IP_WILDCARD, SIP_PORT)
    endpoint = await hass.loop.create_datagram_endpoint(
        protocol_factory,
        local_addr=bind_address,
    )

    # create_datagram_endpoint yields a (transport, protocol) pair.
    return endpoint[0]
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
    """Stop the SIP server, if one is stored, and report success.

    Always returns True, including when no transport was stored under DOMAIN.
    """
    sip_transport = hass.data.pop(DOMAIN, None)
    if sip_transport is None:
        return True

    sip_transport.close()
    _LOGGER.debug("Shut down VoIP server")
    return True
|
|
@ -0,0 +1,54 @@
|
|||
"""Config flow for VoIP integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.const import CONF_IP_ADDRESS
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
from homeassistant.util import network
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_IP_ADDRESS): str,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Single-instance config flow that collects one allowed IPv4 address."""

    VERSION = 1

    async def async_step_user(
        self, user_input: dict[str, Any] | None = None
    ) -> FlowResult:
        """Show the address form, validate the submission, and create the entry.

        Aborts with ``single_instance_allowed`` when an entry already exists.
        A submission that is not a valid IPv4 address re-shows the form with
        an ``invalid_ip_address`` error on the address field.
        """
        if self._async_current_entries():
            return self.async_abort(reason="single_instance_allowed")

        if user_input is not None:
            if network.is_ipv4_address(user_input[CONF_IP_ADDRESS]):
                return self.async_create_entry(
                    title="Voice over IP",
                    data=user_input,
                )

            # Submitted value is not IPv4: ask again and flag the field.
            return self.async_show_form(
                step_id="user",
                data_schema=STEP_USER_DATA_SCHEMA,
                errors={CONF_IP_ADDRESS: "invalid_ip_address"},
            )

        # First visit: nothing submitted yet.
        return self.async_show_form(
            step_id="user", data_schema=STEP_USER_DATA_SCHEMA
        )
|
|
@ -0,0 +1,3 @@
|
|||
"""Constants for the Voice over IP integration."""
|
||||
|
||||
DOMAIN = "voip"
|
|
@ -0,0 +1,11 @@
|
|||
{
|
||||
"domain": "voip",
|
||||
"name": "Voice over IP",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["voice_assistant"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||
"iot_class": "local_push",
|
||||
"quality_scale": "internal",
|
||||
"requirements": ["voip-utils==0.0.2"]
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"config": {
|
||||
"step": {
|
||||
"user": {
|
||||
"data": {
|
||||
"ip_address": "IP Address"
|
||||
}
|
||||
}
|
||||
},
|
||||
"abort": {
|
||||
"single_instance_allowed": "[%key:common::config_flow::abort::single_instance_allowed%]"
|
||||
},
|
||||
"error": {
|
||||
"invalid_ip_address": "Invalid IPv4 address."
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,181 @@
|
|||
"""Voice over IP (VoIP) implementation."""
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
import async_timeout
|
||||
from voip_utils import CallInfo, RtpDatagramProtocol, SdpInfo, VoipDatagramProtocol
|
||||
|
||||
from homeassistant.components import stt, tts
|
||||
from homeassistant.components.voice_assistant import (
|
||||
Pipeline,
|
||||
PipelineEvent,
|
||||
PipelineEventType,
|
||||
async_pipeline_from_audio_stream,
|
||||
)
|
||||
from homeassistant.components.voice_assistant.vad import VoiceCommandSegmenter
|
||||
from homeassistant.const import __version__
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HassVoipDatagramProtocol(VoipDatagramProtocol):
    """SIP datagram protocol that accepts calls only from allowed IPs."""

    def __init__(self, hass: HomeAssistant, allow_ips: set[str]) -> None:
        """Configure the SDP session info and the per-call RTP protocol factory.

        ``allow_ips`` is stored as-is and consulted by ``is_valid_call``.
        """

        def _new_rtp_protocol(call_info: CallInfo) -> "PipelineRtpDatagramProtocol":
            # call_info is not used; the language is read from the Home
            # Assistant config each time the factory is called.
            return PipelineRtpDatagramProtocol(
                hass,
                hass.config.language,
            )

        session = SdpInfo(
            username="homeassistant",
            id=time.monotonic_ns(),
            session_name="voip_hass",
            version=__version__,
        )
        super().__init__(sdp_info=session, protocol_factory=_new_rtp_protocol)
        self.allow_ips = allow_ips

    def is_valid_call(self, call_info: CallInfo) -> bool:
        """Return True when the caller's IP is in the allowed set."""
        caller_allowed = call_info.caller_ip in self.allow_ips
        return caller_allowed
|
||||
|
||||
|
||||
class PipelineRtpDatagramProtocol(RtpDatagramProtocol):
    """RTP protocol that feeds caller audio through voice assistant pipelines.

    A pipeline run is started by the first audio chunk received while no run
    is active. Once a run finishes, the next chunk starts a new one.
    """

    def __init__(
        self,
        hass: HomeAssistant,
        language: str,
        pipeline_timeout: float = 30.0,
        audio_timeout: float = 2.0,
    ) -> None:
        """Store settings and create the audio queue.

        ``pipeline_timeout`` bounds one whole pipeline run; ``audio_timeout``
        bounds the wait for each individual audio chunk. Both are in seconds.
        """
        # STT expects 16Khz mono with 16-bit samples
        super().__init__(rate=16000, width=2, channels=1)

        self.hass = hass
        self.language = language
        self.pipeline_timeout = pipeline_timeout
        self.audio_timeout = audio_timeout
        self.pipeline: Pipeline | None = None

        self._conversation_id: str | None = None
        self._pipeline_task: asyncio.Task | None = None
        self._audio_queue: "asyncio.Queue[bytes]" = asyncio.Queue()

    def connection_made(self, transport):
        """Remember the transport once the endpoint is ready."""
        self.transport = transport

    def on_chunk(self, audio_bytes: bytes) -> None:
        """Queue an audio chunk, starting a pipeline run if none is active."""
        if self._pipeline_task is None:
            # Drop anything left over from before this run.
            stale = self._audio_queue
            while stale.qsize() > 0:
                stale.get_nowait()

            # The run clears _pipeline_task when it ends, so a later chunk
            # starts the next run.
            self._pipeline_task = self.hass.async_create_background_task(
                self._run_pipeline(),
                "voip_pipeline_run",
            )

        self._audio_queue.put_nowait(audio_bytes)

    def _close_rtp_transport(self) -> None:
        """Close the transport if one is set, then clear the reference."""
        if self.transport is None:
            return

        self.transport.close()
        self.transport = None

    async def _run_pipeline(
        self,
    ) -> None:
        """Run one STT-to-TTS pipeline over the queued caller audio."""
        _LOGGER.debug("Starting pipeline")

        async def stt_stream():
            # Yields queued chunks until the segmenter reports the voice
            # command is finished, an empty chunk arrives, or no audio shows
            # up within audio_timeout.
            segmenter = VoiceCommandSegmenter()

            try:
                while True:
                    # Timeout if no audio comes in for a while.
                    # This means the caller hung up.
                    async with async_timeout.timeout(self.audio_timeout):
                        chunk = await self._audio_queue.get()

                    if not chunk or not segmenter.process(chunk):
                        # Voice command is finished
                        break

                    yield chunk
            except asyncio.TimeoutError:
                # Expected after caller hangs up
                _LOGGER.debug("Audio timeout")
                self._close_rtp_transport()

        metadata = stt.SpeechMetadata(
            language="",  # set in async_pipeline_from_audio_stream
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        )

        try:
            # Run pipeline with a timeout
            async with async_timeout.timeout(self.pipeline_timeout):
                await async_pipeline_from_audio_stream(
                    self.hass,
                    event_callback=self._event_callback,
                    stt_metadata=metadata,
                    stt_stream=stt_stream(),
                    language=self.language,
                    conversation_id=self._conversation_id,
                    tts_options={tts.ATTR_AUDIO_OUTPUT: "raw"},
                )
        except asyncio.TimeoutError:
            # Expected after caller hangs up
            _LOGGER.debug("Pipeline timeout")
            self._close_rtp_transport()
        finally:
            # Allow pipeline to run again
            self._pipeline_task = None

    def _event_callback(self, event: PipelineEvent):
        """React to pipeline events; events without data are ignored."""
        payload = event.data
        if not payload:
            return

        if event.type == PipelineEventType.INTENT_END:
            # Keep the conversation id so it is passed to the next run.
            self._conversation_id = payload["intent_output"]["conversation_id"]
        elif event.type == PipelineEventType.TTS_END:
            # Send TTS audio to caller over RTP
            tts_media_id = payload["tts_output"]["media_id"]
            self.hass.async_create_background_task(
                self._send_media(tts_media_id),
                "voip_pipeline_tts",
            )

    async def _send_media(self, media_id: str) -> None:
        """Fetch the TTS audio for ``media_id`` and send it to the caller."""
        if self.transport is None:
            # No transport to send on.
            return

        _extension, audio_bytes = await tts.async_get_media_source_audio(
            self.hass,
            media_id,
        )

        _LOGGER.debug("Sending %s byte(s) of audio", len(audio_bytes))

        # Assume TTS audio is 16Khz 16-bit mono
        await self.send_audio(audio_bytes, rate=16000, width=2, channels=1)
|
|
@ -481,6 +481,7 @@ FLOWS = {
|
|||
"vilfo",
|
||||
"vizio",
|
||||
"vlc_telnet",
|
||||
"voip",
|
||||
"volumio",
|
||||
"volvooncall",
|
||||
"vulcan",
|
||||
|
|
|
@ -6068,6 +6068,12 @@
|
|||
"config_flow": false,
|
||||
"iot_class": "cloud_push"
|
||||
},
|
||||
"voip": {
|
||||
"name": "Voice over IP",
|
||||
"integration_type": "hub",
|
||||
"config_flow": true,
|
||||
"iot_class": "local_push"
|
||||
},
|
||||
"volkszaehler": {
|
||||
"name": "Volkszaehler",
|
||||
"integration_type": "hub",
|
||||
|
|
|
@ -2579,6 +2579,9 @@ venstarcolortouch==0.19
|
|||
# homeassistant.components.vilfo
|
||||
vilfo-api-client==0.3.2
|
||||
|
||||
# homeassistant.components.voip
|
||||
voip-utils==0.0.2
|
||||
|
||||
# homeassistant.components.volkszaehler
|
||||
volkszaehler==0.4.0
|
||||
|
||||
|
|
|
@ -1852,6 +1852,9 @@ venstarcolortouch==0.19
|
|||
# homeassistant.components.vilfo
|
||||
vilfo-api-client==0.3.2
|
||||
|
||||
# homeassistant.components.voip
|
||||
voip-utils==0.0.2
|
||||
|
||||
# homeassistant.components.volvooncall
|
||||
volvooncall==0.10.2
|
||||
|
||||
|
|
|
@ -87,7 +87,7 @@ class MockTTSProvider(tts.Provider):
|
|||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return list of supported options like voice, emotions."""
|
||||
return ["voice", "age"]
|
||||
return ["voice", "age", tts.ATTR_AUDIO_OUTPUT]
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
|
|
|
@ -70,6 +70,7 @@
|
|||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
|
|
|
@ -66,6 +66,7 @@
|
|||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US",
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_-_test.mp3',
|
||||
}),
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
"""Tests for the Voice over IP integration."""
|
|
@ -0,0 +1,83 @@
|
|||
"""Test VoIP config flow."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components import voip
|
||||
from homeassistant.const import CONF_IP_ADDRESS
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def test_form_user(hass: HomeAssistant) -> None:
    """Test that a valid IPv4 address submitted in the user form creates an entry."""
    result = await hass.config_entries.flow.async_init(
        voip.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    # Compare against the enum, like the other result-type checks in this file.
    assert result["type"] == FlowResultType.FORM
    assert not result["errors"]

    # Patch entry setup so no real SIP server is started by the flow.
    with patch(
        "homeassistant.components.voip.async_setup_entry",
        return_value=True,
    ) as mock_setup_entry:
        result = await hass.config_entries.flow.async_configure(
            result["flow_id"],
            {CONF_IP_ADDRESS: "127.0.0.1"},
        )
        await hass.async_block_till_done()

    assert result["type"] == FlowResultType.CREATE_ENTRY
    assert result["data"] == {CONF_IP_ADDRESS: "127.0.0.1"}
    assert len(mock_setup_entry.mock_calls) == 1
|
||||
|
||||
|
||||
async def test_invalid_ip(hass: HomeAssistant) -> None:
    """Test that a non-IPv4 submission re-shows the form with a field error."""
    result = await hass.config_entries.flow.async_init(
        voip.DOMAIN, context={"source": config_entries.SOURCE_USER}
    )
    # Use the enum here too, matching the final assertion below.
    assert result["type"] == FlowResultType.FORM
    assert not result["errors"]

    result = await hass.config_entries.flow.async_configure(
        result["flow_id"],
        {CONF_IP_ADDRESS: "not an ip address"},
    )
    await hass.async_block_till_done()

    assert result["type"] == FlowResultType.FORM
    assert result["errors"] == {CONF_IP_ADDRESS: "invalid_ip_address"}
|
||||
|
||||
|
||||
async def test_load_unload_entry(
    hass: HomeAssistant,
    socket_enabled,
    unused_udp_port_factory,
) -> None:
    """Test setting up and unloading VoIP, and that a second flow is refused."""
    entry = MockConfigEntry(
        domain=voip.DOMAIN,
        data={
            CONF_IP_ADDRESS: "127.0.0.1",
        },
    )
    entry.add_to_hass(hass)

    # Replace SIP_PORT with a port from the fixture for the duration of the test.
    with patch(
        "homeassistant.components.voip.SIP_PORT",
        new=unused_udp_port_factory(),
    ):
        assert await voip.async_setup_entry(hass, entry)

        # Verify single instance
        result = await hass.config_entries.flow.async_init(
            voip.DOMAIN, context={"source": config_entries.SOURCE_USER}
        )
        # Enum comparison, consistent with the other tests in this file.
        assert result["type"] == FlowResultType.ABORT
        assert result["reason"] == "single_instance_allowed"

        assert await voip.async_unload_entry(hass, entry)
|
|
@ -0,0 +1,171 @@
|
|||
"""Test VoIP protocol."""
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import async_timeout
|
||||
|
||||
from homeassistant.components import voice_assistant, voip
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||
_MEDIA_ID = "12345"
|
||||
|
||||
|
||||
async def test_pipeline(hass: HomeAssistant) -> None:
    """Test that the RTP protocol runs the pipeline and sends TTS audio back."""
    assert await async_setup_component(hass, "voip", {})

    audio_sent = asyncio.Event()

    # Queued before the first chunk arrives; the protocol must discard it
    # when it starts the pipeline, so the stream must never yield it.
    bad_chunk = bytes([1, 2, 3, 4])

    def is_speech(self, chunk, sample_rate):
        """Anything non-zero is speech."""
        return sum(chunk) > 0

    async def fake_pipeline(*args, **kwargs):
        audio_stream = kwargs["stt_stream"]
        emit = kwargs["event_callback"]

        # Stream will end when VAD detects end of "speech"
        async for received in audio_stream:
            assert received != bad_chunk

        scripted_events = (
            # Test empty data
            voice_assistant.PipelineEvent(
                type="not-used",
                data={},
            ),
            # Fake intent result
            voice_assistant.PipelineEvent(
                type=voice_assistant.PipelineEventType.INTENT_END,
                data={
                    "intent_output": {
                        "conversation_id": "fake-conversation",
                    }
                },
            ),
            # Proceed with media output
            voice_assistant.PipelineEvent(
                type=voice_assistant.PipelineEventType.TTS_END,
                data={"tts_output": {"media_id": _MEDIA_ID}},
            ),
        )
        for scripted in scripted_events:
            emit(scripted)

    async def fake_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        assert media_source_id == _MEDIA_ID
        return ("mp3", b"")

    async def fake_send_audio(*args, **kwargs):
        # Test finished successfully
        audio_sent.set()

    with patch(
        "webrtcvad.Vad.is_speech",
        new=is_speech,
    ), patch(
        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
        new=fake_pipeline,
    ), patch(
        "homeassistant.components.voip.voip.tts.async_get_media_source_audio",
        new=fake_get_media_source_audio,
    ):
        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass,
            hass.config.language,
        )
        rtp_protocol.transport = Mock()
        rtp_protocol.send_audio = Mock(side_effect=fake_send_audio)

        # Ensure audio queue is cleared before pipeline starts
        rtp_protocol._audio_queue.put_nowait(bad_chunk)

        one_second_of_silence = bytes(_ONE_SECOND)
        two_seconds_of_speech = bytes([255] * _ONE_SECOND * 2)

        # silence, then "speech", then silence again
        rtp_protocol.on_chunk(one_second_of_silence)
        rtp_protocol.on_chunk(two_seconds_of_speech)
        rtp_protocol.on_chunk(one_second_of_silence)

        # Wait for mock pipeline to exhaust the audio stream
        async with async_timeout.timeout(1):
            await audio_sent.wait()
|
||||
|
||||
|
||||
async def test_pipeline_timeout(hass: HomeAssistant) -> None:
    """Test that a pipeline run exceeding its timeout closes the transport."""
    assert await async_setup_component(hass, "voip", {})

    transport_closed = asyncio.Event()

    async def slow_pipeline(*args, **kwargs):
        # Far longer than the pipeline_timeout configured below.
        await asyncio.sleep(10)

    with patch(
        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
        new=slow_pipeline,
    ):
        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass, hass.config.language, pipeline_timeout=0.001
        )

        # Closing the transport will cause the test to succeed
        transport = Mock(spec=["close"])
        transport.close.side_effect = transport_closed.set
        rtp_protocol.connection_made(transport)

        # silence
        rtp_protocol.on_chunk(bytes(_ONE_SECOND))

        # Wait for mock pipeline to time out
        async with async_timeout.timeout(1):
            await transport_closed.wait()
|
||||
|
||||
|
||||
async def test_stt_stream_timeout(hass: HomeAssistant) -> None:
    """Test that an audio timeout inside the STT stream closes the transport."""
    assert await async_setup_component(hass, "voip", {})

    transport_closed = asyncio.Event()

    async def draining_pipeline(*args, **kwargs):
        # Iterate over stream
        async for _chunk in kwargs["stt_stream"]:
            continue

    with patch(
        "homeassistant.components.voip.voip.async_pipeline_from_audio_stream",
        new=draining_pipeline,
    ):
        rtp_protocol = voip.voip.PipelineRtpDatagramProtocol(
            hass, hass.config.language, audio_timeout=0.001
        )

        # Closing the transport will cause the test to succeed
        transport = Mock(spec=["close"])
        transport.close.side_effect = transport_closed.set
        rtp_protocol.connection_made(transport)

        # silence
        rtp_protocol.on_chunk(bytes(_ONE_SECOND))

        # Wait for mock pipeline to time out
        async with async_timeout.timeout(1):
            await transport_closed.wait()
|
Loading…
Reference in New Issue