Send/receive Voice Assistant audio via ESPHome native API (#114800)

* Protobuf audio test

* Remove extraneous code

* Rework voice assistant pipeline

* Move variables

* Fix reading flags

* Don't directly put to queue

* Bump aioesphomeapi to 24.0.0

* Update tests

- Add more tests for API pipeline
- Convert some udp tests to use api pipeline
- Update fixtures for new device info flags

* Fix bad merge

---------

Co-authored-by: Michael Hansen <mike@rhasspy.org>
pull/115322/head
Jesse Hills 2024-04-10 02:55:59 +12:00 committed by GitHub
parent cad4c3c0c2
commit 68384bba67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 495 additions and 230 deletions

View File

@ -33,7 +33,9 @@ async def async_setup_entry(
entry_data = DomainData.get(hass).get_entry_data(entry) entry_data = DomainData.get(hass).get_entry_data(entry)
assert entry_data.device_info is not None assert entry_data.device_info is not None
if entry_data.device_info.voice_assistant_version: if entry_data.device_info.voice_assistant_feature_flags_compat(
entry_data.api_version
):
async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)]) async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)])

View File

@ -257,7 +257,9 @@ class RuntimeEntryData:
if async_get_dashboard(hass): if async_get_dashboard(hass):
needed_platforms.add(Platform.UPDATE) needed_platforms.add(Platform.UPDATE)
if self.device_info and self.device_info.voice_assistant_version: if self.device_info and self.device_info.voice_assistant_feature_flags_compat(
self.api_version
):
needed_platforms.add(Platform.BINARY_SENSOR) needed_platforms.add(Platform.BINARY_SENSOR)
needed_platforms.add(Platform.SELECT) needed_platforms.add(Platform.SELECT)

View File

@ -21,6 +21,7 @@ from aioesphomeapi import (
UserService, UserService,
UserServiceArgType, UserServiceArgType,
VoiceAssistantAudioSettings, VoiceAssistantAudioSettings,
VoiceAssistantFeature,
) )
from awesomeversion import AwesomeVersion from awesomeversion import AwesomeVersion
import voluptuous as vol import voluptuous as vol
@ -72,7 +73,11 @@ from .domain_data import DomainData
# Import config flow so that it's added to the registry # Import config flow so that it's added to the registry
from .entry_data import RuntimeEntryData from .entry_data import RuntimeEntryData
from .voice_assistant import VoiceAssistantUDPServer from .voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantPipeline,
VoiceAssistantUDPPipeline,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -143,7 +148,7 @@ class ESPHomeManager:
"cli", "cli",
"device_id", "device_id",
"domain_data", "domain_data",
"voice_assistant_udp_server", "voice_assistant_pipeline",
"reconnect_logic", "reconnect_logic",
"zeroconf_instance", "zeroconf_instance",
"entry_data", "entry_data",
@ -168,7 +173,7 @@ class ESPHomeManager:
self.cli = cli self.cli = cli
self.device_id: str | None = None self.device_id: str | None = None
self.domain_data = domain_data self.domain_data = domain_data
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
self.reconnect_logic: ReconnectLogic | None = None self.reconnect_logic: ReconnectLogic | None = None
self.zeroconf_instance = zeroconf_instance self.zeroconf_instance = zeroconf_instance
self.entry_data = entry_data self.entry_data = entry_data
@ -327,9 +332,10 @@ class ESPHomeManager:
def _handle_pipeline_finished(self) -> None: def _handle_pipeline_finished(self) -> None:
self.entry_data.async_set_assist_pipeline_state(False) self.entry_data.async_set_assist_pipeline_state(False)
if self.voice_assistant_udp_server is not None: if self.voice_assistant_pipeline is not None:
self.voice_assistant_udp_server.close() if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline):
self.voice_assistant_udp_server = None self.voice_assistant_pipeline.close()
self.voice_assistant_pipeline = None
async def _handle_pipeline_start( async def _handle_pipeline_start(
self, self,
@ -339,38 +345,60 @@ class ESPHomeManager:
wake_word_phrase: str | None, wake_word_phrase: str | None,
) -> int | None: ) -> int | None:
"""Start a voice assistant pipeline.""" """Start a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None: if self.voice_assistant_pipeline is not None:
_LOGGER.warning("Voice assistant UDP server was not stopped") _LOGGER.warning("Voice assistant UDP server was not stopped")
self.voice_assistant_udp_server.stop() self.voice_assistant_pipeline.stop()
self.voice_assistant_udp_server = None self.voice_assistant_pipeline = None
hass = self.hass hass = self.hass
self.voice_assistant_udp_server = VoiceAssistantUDPServer( assert self.entry_data.device_info is not None
hass, if (
self.entry_data, self.entry_data.device_info.voice_assistant_feature_flags_compat(
self.cli.send_voice_assistant_event, self.entry_data.api_version
self._handle_pipeline_finished, )
) & VoiceAssistantFeature.API_AUDIO
port = await self.voice_assistant_udp_server.start_server() ):
self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
hass,
self.entry_data,
self.cli.send_voice_assistant_event,
self._handle_pipeline_finished,
self.cli,
)
port = 0
else:
self.voice_assistant_pipeline = VoiceAssistantUDPPipeline(
hass,
self.entry_data,
self.cli.send_voice_assistant_event,
self._handle_pipeline_finished,
)
port = await self.voice_assistant_pipeline.start_server()
assert self.device_id is not None, "Device ID must be set" assert self.device_id is not None, "Device ID must be set"
hass.async_create_background_task( hass.async_create_background_task(
self.voice_assistant_udp_server.run_pipeline( self.voice_assistant_pipeline.run_pipeline(
device_id=self.device_id, device_id=self.device_id,
conversation_id=conversation_id or None, conversation_id=conversation_id or None,
flags=flags, flags=flags,
audio_settings=audio_settings, audio_settings=audio_settings,
wake_word_phrase=wake_word_phrase, wake_word_phrase=wake_word_phrase,
), ),
"esphome.voice_assistant_udp_server.run_pipeline", "esphome.voice_assistant_pipeline.run_pipeline",
) )
return port return port
async def _handle_pipeline_stop(self) -> None: async def _handle_pipeline_stop(self) -> None:
"""Stop a voice assistant pipeline.""" """Stop a voice assistant pipeline."""
if self.voice_assistant_udp_server is not None: if self.voice_assistant_pipeline is not None:
self.voice_assistant_udp_server.stop() self.voice_assistant_pipeline.stop()
async def _handle_audio(self, data: bytes) -> None:
    """Forward a chunk of audio received from the device to the active pipeline.

    Does nothing when no pipeline is currently stored on the manager.
    """
    pipeline = self.voice_assistant_pipeline
    if pipeline is not None:
        # receive_audio_bytes is only defined on the API pipeline, so narrow
        # the type before calling it.
        assert isinstance(pipeline, VoiceAssistantAPIPipeline)
        pipeline.receive_audio_bytes(data)
async def on_connect(self) -> None: async def on_connect(self) -> None:
"""Subscribe to states and list entities on successful API login.""" """Subscribe to states and list entities on successful API login."""
@ -472,13 +500,23 @@ class ESPHomeManager:
) )
) )
if device_info.voice_assistant_version: flags = device_info.voice_assistant_feature_flags_compat(api_version)
entry_data.disconnect_callbacks.add( if flags:
cli.subscribe_voice_assistant( if flags & VoiceAssistantFeature.API_AUDIO:
self._handle_pipeline_start, entry_data.disconnect_callbacks.add(
self._handle_pipeline_stop, cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
handle_audio=self._handle_audio,
)
)
else:
entry_data.disconnect_callbacks.add(
cli.subscribe_voice_assistant(
handle_start=self._handle_pipeline_start,
handle_stop=self._handle_pipeline_stop,
)
) )
)
cli.subscribe_states(entry_data.async_update_state) cli.subscribe_states(entry_data.async_update_state)
cli.subscribe_service_calls(self.async_on_service_call) cli.subscribe_service_calls(self.async_on_service_call)

View File

@ -15,7 +15,7 @@
"iot_class": "local_push", "iot_class": "local_push",
"loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"], "loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"],
"requirements": [ "requirements": [
"aioesphomeapi==23.2.0", "aioesphomeapi==24.0.0",
"esphome-dashboard-api==1.2.3", "esphome-dashboard-api==1.2.3",
"bleak-esphome==1.0.0" "bleak-esphome==1.0.0"
], ],

View File

@ -42,7 +42,9 @@ async def async_setup_entry(
entry_data = DomainData.get(hass).get_entry_data(entry) entry_data = DomainData.get(hass).get_entry_data(entry)
assert entry_data.device_info is not None assert entry_data.device_info is not None
if entry_data.device_info.voice_assistant_version: if entry_data.device_info.voice_assistant_feature_flags_compat(
entry_data.api_version
):
async_add_entities( async_add_entities(
[ [
EsphomeAssistPipelineSelect(hass, entry_data), EsphomeAssistPipelineSelect(hass, entry_data),

View File

@ -11,9 +11,11 @@ from typing import cast
import wave import wave
from aioesphomeapi import ( from aioesphomeapi import (
APIClient,
VoiceAssistantAudioSettings, VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag, VoiceAssistantCommandFlag,
VoiceAssistantEventType, VoiceAssistantEventType,
VoiceAssistantFeature,
) )
from homeassistant.components import stt, tts from homeassistant.components import stt, tts
@ -64,13 +66,11 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
) )
class VoiceAssistantUDPServer(asyncio.DatagramProtocol): class VoiceAssistantPipeline:
"""Receive UDP packets and forward them to the voice assistant.""" """Base abstract pipeline class."""
started = False started = False
stop_requested = False stop_requested = False
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
def __init__( def __init__(
self, self,
@ -79,12 +79,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
handle_finished: Callable[[], None], handle_finished: Callable[[], None],
) -> None: ) -> None:
"""Initialize UDP receiver.""" """Initialize the pipeline."""
self.context = Context() self.context = Context()
self.hass = hass self.hass = hass
assert entry_data.device_info is not None
self.entry_data = entry_data self.entry_data = entry_data
assert entry_data.device_info is not None
self.device_info = entry_data.device_info self.device_info = entry_data.device_info
self.queue: asyncio.Queue[bytes] = asyncio.Queue() self.queue: asyncio.Queue[bytes] = asyncio.Queue()
@ -95,69 +94,9 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
"""True if the UDP server is started and hasn't been asked to stop.""" """True if the pipeline is started and hasn't been asked to stop."""
return self.started and (not self.stop_requested) return self.started and (not self.stop_requested)
async def start_server(self) -> int:
"""Start accepting connections."""
def accept_connection() -> VoiceAssistantUDPServer:
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.stop_requested:
raise RuntimeError("No longer accepting connections")
self.started = True
return self
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", UDP_PORT))
await asyncio.get_running_loop().create_datagram_endpoint(
accept_connection, sock=sock
)
return cast(int, sock.getsockname()[1])
@callback
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Store transport for later use."""
self.transport = cast(asyncio.DatagramTransport, transport)
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if not self.is_running:
return
if self.remote_addr is None:
self.remote_addr = addr
self.queue.put_nowait(data)
def error_received(self, exc: Exception) -> None:
"""Handle when a send or receive operation raises an OSError.
(Other than BlockingIOError or InterruptedError.)
"""
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
self.handle_finished()
@callback
def stop(self) -> None:
"""Stop the receiver."""
self.queue.put_nowait(b"")
self.close()
def close(self) -> None:
"""Close the receiver."""
self.started = False
self.stop_requested = True
if self.transport is not None:
self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]: async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets.""" """Iterate over incoming packets."""
while data := await self.queue.get(): while data := await self.queue.get():
@ -198,7 +137,12 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
url = async_process_play_media_url(self.hass, path) url = async_process_play_media_url(self.hass, path)
data_to_send = {"url": url} data_to_send = {"url": url}
if self.device_info.voice_assistant_version >= 2: if (
self.device_info.voice_assistant_feature_flags_compat(
self.entry_data.api_version
)
& VoiceAssistantFeature.SPEAKER
):
media_id = tts_output["media_id"] media_id = tts_output["media_id"]
self._tts_task = self.hass.async_create_background_task( self._tts_task = self.hass.async_create_background_task(
self._send_tts(media_id), "esphome_voice_assistant_tts" self._send_tts(media_id), "esphome_voice_assistant_tts"
@ -243,9 +187,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
if audio_settings is None or audio_settings.volume_multiplier == 0: if audio_settings is None or audio_settings.volume_multiplier == 0:
audio_settings = VoiceAssistantAudioSettings() audio_settings = VoiceAssistantAudioSettings()
tts_audio_output = ( if (
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3" self.device_info.voice_assistant_feature_flags_compat(
) self.entry_data.api_version
)
& VoiceAssistantFeature.SPEAKER
):
tts_audio_output = "wav"
else:
tts_audio_output = "mp3"
_LOGGER.debug("Starting pipeline") _LOGGER.debug("Starting pipeline")
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD: if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
@ -315,7 +265,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}) self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
try: try:
if (not self.is_running) or (self.transport is None): if not self.is_running:
return return
extension, data = await tts.async_get_media_source_audio( extension, data = await tts.async_get_media_source_audio(
@ -358,16 +308,133 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
samples_in_chunk = len(chunk) // bytes_per_sample samples_in_chunk = len(chunk) // bytes_per_sample
samples_left -= samples_in_chunk samples_left -= samples_in_chunk
self.transport.sendto(chunk, self.remote_addr) self.send_audio_bytes(chunk)
await asyncio.sleep( await asyncio.sleep(
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9 samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
) )
sample_offset += samples_in_chunk sample_offset += samples_in_chunk
finally: finally:
self.handle_event( self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
) )
self._tts_task = None self._tts_task = None
self._tts_done.set() self._tts_done.set()
def send_audio_bytes(self, data: bytes) -> None:
    """Send audio bytes to the device.

    The base pipeline provides no implementation and always raises
    NotImplementedError; a subclass must override this method.
    """
    raise NotImplementedError
def stop(self) -> None:
    """Stop the pipeline by queueing an empty end-of-stream packet."""
    # An empty bytes object is falsy, so it terminates the
    # `while data := await self.queue.get()` loop in _iterate_packets.
    end_marker = b""
    self.queue.put_nowait(end_marker)
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
    """Voice assistant pipeline that exchanges audio with the device over UDP."""

    # Datagram transport; stored by connection_made.
    transport: asyncio.DatagramTransport | None = None
    # Address of the first sender seen; send_audio_bytes sends to it.
    remote_addr: tuple[str, int] | None = None

    async def start_server(self) -> int:
        """Bind the UDP socket, attach this protocol and return the bound port."""

        def accept_connection() -> VoiceAssistantUDPPipeline:
            """Protocol factory: mark this instance as started and return it."""
            if self.started:
                raise RuntimeError("Can only start once")
            if self.stop_requested:
                raise RuntimeError("No longer accepting connections")
            self.started = True
            return self

        udp_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        udp_socket.setblocking(False)
        udp_socket.bind(("", UDP_PORT))

        loop = asyncio.get_running_loop()
        await loop.create_datagram_endpoint(accept_connection, sock=udp_socket)

        # Report the port the socket actually ended up bound to.
        bound_port = udp_socket.getsockname()[1]
        return cast(int, bound_port)

    @callback
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Keep the transport so audio can be sent and the endpoint closed."""
        self.transport = cast(asyncio.DatagramTransport, transport)

    @callback
    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        """Queue an incoming UDP packet while the pipeline is running."""
        if self.is_running:
            if self.remote_addr is None:
                # The first sender becomes the destination for outgoing audio.
                self.remote_addr = addr
            self.queue.put_nowait(data)

    def error_received(self, exc: Exception) -> None:
        """Log an OSError raised by a send or receive operation and finish.

        (Other than BlockingIOError or InterruptedError.)
        """
        _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
        self.handle_finished()

    @callback
    def stop(self) -> None:
        """Stop the receiver: run the base stop, then close."""
        super().stop()
        self.close()

    def close(self) -> None:
        """Mark the receiver as stopped and close the transport if there is one."""
        self.started, self.stop_requested = False, True
        transport = self.transport
        if transport is not None:
            transport.close()

    def send_audio_bytes(self, data: bytes) -> None:
        """Send bytes to the device via UDP."""
        transport = self.transport
        if transport is None:
            _LOGGER.error("No transport to send audio to")
            return
        transport.sendto(data, self.remote_addr)
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
    """Voice assistant pipeline that exchanges audio with the device via the API."""

    def __init__(
        self,
        hass: HomeAssistant,
        entry_data: RuntimeEntryData,
        handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
        handle_finished: Callable[[], None],
        api_client: APIClient,
    ) -> None:
        """Initialize the pipeline and mark it as started straight away."""
        super().__init__(hass, entry_data, handle_event, handle_finished)
        # This class has no separate start step, so it is started on creation.
        self.started = True
        self.api_client = api_client

    def send_audio_bytes(self, data: bytes) -> None:
        """Send bytes to the device via the API."""
        self.api_client.send_voice_assistant_audio(data)

    @callback
    def receive_audio_bytes(self, data: bytes) -> None:
        """Queue audio bytes received from the device while running."""
        if self.is_running:
            self.queue.put_nowait(data)

    @callback
    def stop(self) -> None:
        """Stop the pipeline: run the base stop, then flag it as stopped."""
        super().stop()
        self.started, self.stop_requested = False, True

View File

@ -242,7 +242,7 @@ aioelectricitymaps==0.4.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==23.2.0 aioesphomeapi==24.0.0
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0

View File

@ -221,7 +221,7 @@ aioelectricitymaps==0.4.0
aioemonitor==1.0.5 aioemonitor==1.0.5
# homeassistant.components.esphome # homeassistant.components.esphome
aioesphomeapi==23.2.0 aioesphomeapi==24.0.0
# homeassistant.components.flo # homeassistant.components.flo
aioflo==2021.11.0 aioflo==2021.11.0

View File

@ -18,6 +18,7 @@ from aioesphomeapi import (
HomeassistantServiceCall, HomeassistantServiceCall,
ReconnectLogic, ReconnectLogic,
UserService, UserService,
VoiceAssistantFeature,
) )
import pytest import pytest
from zeroconf import Zeroconf from zeroconf import Zeroconf
@ -354,10 +355,16 @@ async def mock_voice_assistant_entry(
): ):
"""Set up an ESPHome entry with voice assistant.""" """Set up an ESPHome entry with voice assistant."""
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry: async def _mock_voice_assistant_entry(
voice_assistant_feature_flags: VoiceAssistantFeature,
) -> MockConfigEntry:
return ( return (
await _mock_generic_device_entry( await _mock_generic_device_entry(
hass, mock_client, {"voice_assistant_version": version}, ([], []), [] hass,
mock_client,
{"voice_assistant_feature_flags": voice_assistant_feature_flags},
([], []),
[],
) )
).entry ).entry
@ -367,13 +374,28 @@ async def mock_voice_assistant_entry(
@pytest.fixture @pytest.fixture
async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry: async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant.""" """Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=1) return await mock_voice_assistant_entry(
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
)
@pytest.fixture @pytest.fixture
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry: async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
"""Set up an ESPHome entry with voice assistant.""" """Set up an ESPHome entry with voice assistant."""
return await mock_voice_assistant_entry(version=2) return await mock_voice_assistant_entry(
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
)
@pytest.fixture
async def mock_voice_assistant_api_entry(mock_voice_assistant_entry) -> MockConfigEntry:
    """Set up an ESPHome entry with a voice assistant that has API audio enabled."""
    # Same flags as the v2 entry plus API_AUDIO.
    api_audio_flags = (
        VoiceAssistantFeature.VOICE_ASSISTANT
        | VoiceAssistantFeature.SPEAKER
        | VoiceAssistantFeature.API_AUDIO
    )
    return await mock_voice_assistant_entry(
        voice_assistant_feature_flags=api_audio_flags
    )
@pytest.fixture @pytest.fixture

View File

@ -94,7 +94,8 @@ async def test_diagnostics_with_bluetooth(
"project_version": "", "project_version": "",
"suggested_area": "", "suggested_area": "",
"uses_password": False, "uses_password": False,
"voice_assistant_version": 0, "legacy_voice_assistant_version": 0,
"voice_assistant_feature_flags": 0,
"webserver_port": 0, "webserver_port": 0,
}, },
"services": [], "services": [],

View File

@ -6,7 +6,7 @@ import socket
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import wave import wave
from aioesphomeapi import VoiceAssistantEventType from aioesphomeapi import APIClient, VoiceAssistantEventType
import pytest import pytest
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
@ -19,7 +19,10 @@ from homeassistant.components.assist_pipeline.error import (
WakeWordDetectionError, WakeWordDetectionError,
) )
from homeassistant.components.esphome import DomainData from homeassistant.components.esphome import DomainData
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer from homeassistant.components.esphome.voice_assistant import (
VoiceAssistantAPIPipeline,
VoiceAssistantUDPPipeline,
)
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
_TEST_INPUT_TEXT = "This is an input test" _TEST_INPUT_TEXT = "This is an input test"
@ -31,43 +34,54 @@ _ONE_SECOND = 16000 * 2 # 16Khz 16-bit
@pytest.fixture @pytest.fixture
def voice_assistant_udp_server( def voice_assistant_udp_pipeline(
hass: HomeAssistant, hass: HomeAssistant,
) -> VoiceAssistantUDPServer: ) -> VoiceAssistantUDPPipeline:
"""Return the UDP server factory.""" """Return the UDP pipeline factory."""
def _voice_assistant_udp_server(entry): def _voice_assistant_udp_server(entry):
entry_data = DomainData.get(hass).get_entry_data(entry) entry_data = DomainData.get(hass).get_entry_data(entry)
server: VoiceAssistantUDPServer = None server: VoiceAssistantUDPPipeline = None
def handle_finished(): def handle_finished():
nonlocal server nonlocal server
assert server is not None assert server is not None
server.close() server.close()
server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished)
return server # noqa: RET504 return server # noqa: RET504
return _voice_assistant_udp_server return _voice_assistant_udp_server
@pytest.fixture @pytest.fixture
def voice_assistant_udp_server_v1( def voice_assistant_api_pipeline(
voice_assistant_udp_server, hass: HomeAssistant,
mock_voice_assistant_v1_entry, mock_client,
) -> VoiceAssistantUDPServer: mock_voice_assistant_api_entry,
"""Return the UDP server.""" ) -> VoiceAssistantAPIPipeline:
return voice_assistant_udp_server(entry=mock_voice_assistant_v1_entry) """Return the API Pipeline factory."""
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry)
return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client)
@pytest.fixture @pytest.fixture
def voice_assistant_udp_server_v2( def voice_assistant_udp_pipeline_v1(
voice_assistant_udp_server, voice_assistant_udp_pipeline,
mock_voice_assistant_v1_entry,
) -> VoiceAssistantUDPPipeline:
"""Return the UDP pipeline."""
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry)
@pytest.fixture
def voice_assistant_udp_pipeline_v2(
voice_assistant_udp_pipeline,
mock_voice_assistant_v2_entry, mock_voice_assistant_v2_entry,
) -> VoiceAssistantUDPServer: ) -> VoiceAssistantUDPPipeline:
"""Return the UDP server.""" """Return the UDP pipeline."""
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry) return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry)
@pytest.fixture @pytest.fixture
@ -85,7 +99,7 @@ def test_wav() -> bytes:
async def test_pipeline_events( async def test_pipeline_events(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test that the pipeline function is called.""" """Test that the pipeline function is called."""
@ -145,15 +159,15 @@ async def test_pipeline_events(
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
assert data is None assert data is None
voice_assistant_udp_server_v1.handle_event = handle_event voice_assistant_udp_pipeline_v1.handle_event = handle_event
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
): ):
voice_assistant_udp_server_v1.transport = Mock() voice_assistant_udp_pipeline_v1.transport = Mock()
await voice_assistant_udp_server_v1.run_pipeline( await voice_assistant_udp_pipeline_v1.run_pipeline(
device_id="mock-device-id", conversation_id=None device_id="mock-device-id", conversation_id=None
) )
@ -162,7 +176,7 @@ async def test_udp_server(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test the UDP server runs and queues incoming data.""" """Test the UDP server runs and queues incoming data."""
port_to_use = unused_udp_port_factory() port_to_use = unused_udp_port_factory()
@ -170,93 +184,133 @@ async def test_udp_server(
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
): ):
port = await voice_assistant_udp_server_v1.start_server() port = await voice_assistant_udp_pipeline_v1.start_server()
assert port == port_to_use assert port == port_to_use
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
assert voice_assistant_udp_server_v1.queue.qsize() == 0 assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
sock.sendto(b"test", ("127.0.0.1", port)) sock.sendto(b"test", ("127.0.0.1", port))
# Give the socket some time to send/receive the data # Give the socket some time to send/receive the data
async with asyncio.timeout(1): async with asyncio.timeout(1):
while voice_assistant_udp_server_v1.queue.qsize() == 0: while voice_assistant_udp_pipeline_v1.queue.qsize() == 0:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert voice_assistant_udp_server_v1.queue.qsize() == 1 assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
voice_assistant_udp_server_v1.stop() voice_assistant_udp_pipeline_v1.stop()
voice_assistant_udp_server_v1.close() voice_assistant_udp_pipeline_v1.close()
assert voice_assistant_udp_server_v1.transport.is_closing() assert voice_assistant_udp_pipeline_v1.transport.is_closing()
async def test_udp_server_queue( async def test_udp_server_queue(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test the UDP server queues incoming data.""" """Test the UDP server queues incoming data."""
voice_assistant_udp_server_v1.started = True voice_assistant_udp_pipeline_v1.started = True
assert voice_assistant_udp_server_v1.queue.qsize() == 0 assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
assert voice_assistant_udp_server_v1.queue.qsize() == 1 assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
assert voice_assistant_udp_server_v1.queue.qsize() == 2 assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2
async for data in voice_assistant_udp_server_v1._iterate_packets(): async for data in voice_assistant_udp_pipeline_v1._iterate_packets():
assert data == bytes(1024) assert data == bytes(1024)
break break
assert voice_assistant_udp_server_v1.queue.qsize() == 1 # One message removed assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed
voice_assistant_udp_server_v1.stop() voice_assistant_udp_pipeline_v1.stop()
assert ( assert (
voice_assistant_udp_server_v1.queue.qsize() == 2 voice_assistant_udp_pipeline_v1.queue.qsize() == 2
) # An empty message added by stop ) # An empty message added by stop
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
assert ( assert (
voice_assistant_udp_server_v1.queue.qsize() == 2 voice_assistant_udp_pipeline_v1.queue.qsize() == 2
) # No new messages added after stop ) # No new messages added after stop
voice_assistant_udp_server_v1.close() voice_assistant_udp_pipeline_v1.close()
# Stopping the UDP server should cause _iterate_packets to break out # Stopping the UDP server should cause _iterate_packets to break out
# immediately without yielding any data. # immediately without yielding any data.
has_data = False has_data = False
async for _data in voice_assistant_udp_server_v1._iterate_packets(): async for _data in voice_assistant_udp_pipeline_v1._iterate_packets():
has_data = True has_data = True
assert not has_data, "Server was stopped" assert not has_data, "Server was stopped"
async def test_api_pipeline_queue(
    hass: HomeAssistant,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
    """Test the API pipeline queues incoming data."""
    pipeline = voice_assistant_api_pipeline
    audio = bytes(1024)

    pipeline.started = True
    assert pipeline.queue.qsize() == 0

    # Each received chunk must grow the queue by exactly one entry.
    for expected_size in (1, 2):
        pipeline.receive_audio_bytes(audio)
        assert pipeline.queue.qsize() == expected_size

    async for data in pipeline._iterate_packets():
        assert data == audio
        break
    assert pipeline.queue.qsize() == 1  # One message removed

    pipeline.stop()
    assert pipeline.queue.qsize() == 2  # An empty message added by stop

    pipeline.receive_audio_bytes(audio)
    assert pipeline.queue.qsize() == 2  # No new messages added after stop

    # Stopping the API Pipeline should cause _iterate_packets to break out
    # immediately without yielding any data.
    leftover = [packet async for packet in pipeline._iterate_packets()]
    assert not leftover, "Pipeline was stopped"
async def test_error_calls_handle_finished( async def test_error_calls_handle_finished(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test that the handle_finished callback is called when an error occurs.""" """Test that the handle_finished callback is called when an error occurs."""
voice_assistant_udp_server_v1.handle_finished = Mock() voice_assistant_udp_pipeline_v1.handle_finished = Mock()
voice_assistant_udp_server_v1.error_received(Exception()) voice_assistant_udp_pipeline_v1.error_received(Exception())
voice_assistant_udp_server_v1.handle_finished.assert_called() voice_assistant_udp_pipeline_v1.handle_finished.assert_called()
async def test_udp_server_multiple( async def test_udp_server_multiple(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test that the UDP server raises an error if started twice.""" """Test that the UDP server raises an error if started twice."""
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", "homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(), new=unused_udp_port_factory(),
): ):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_pipeline_v1.start_server()
with ( with (
patch( patch(
@ -265,17 +319,17 @@ async def test_udp_server_multiple(
), ),
pytest.raises(RuntimeError), pytest.raises(RuntimeError),
): ):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_pipeline_v1.start_server()
async def test_udp_server_after_stopped( async def test_udp_server_after_stopped(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test that the UDP server raises an error if started after stopped.""" """Test that the UDP server raises an error if started after stopped."""
voice_assistant_udp_server_v1.close() voice_assistant_udp_pipeline_v1.close()
with ( with (
patch( patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", "homeassistant.components.esphome.voice_assistant.UDP_PORT",
@ -283,37 +337,37 @@ async def test_udp_server_after_stopped(
), ),
pytest.raises(RuntimeError), pytest.raises(RuntimeError),
): ):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_pipeline_v1.start_server()
async def test_unknown_event_type( async def test_unknown_event_type(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test the UDP server does not call handle_event for unknown events.""" """Test the API pipeline does not call handle_event for unknown events."""
voice_assistant_udp_server_v1._event_callback( voice_assistant_api_pipeline._event_callback(
PipelineEvent( PipelineEvent(
type="unknown-event", type="unknown-event",
data={}, data={},
) )
) )
assert not voice_assistant_udp_server_v1.handle_event.called assert not voice_assistant_api_pipeline.handle_event.called
async def test_error_event_type( async def test_error_event_type(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test the UDP server calls event handler with error.""" """Test the API pipeline calls event handler with error."""
voice_assistant_udp_server_v1._event_callback( voice_assistant_api_pipeline._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.ERROR, type=PipelineEventType.ERROR,
data={"code": "code", "message": "message"}, data={"code": "code", "message": "message"},
) )
) )
voice_assistant_udp_server_v1.handle_event.assert_called_with( voice_assistant_api_pipeline.handle_event.assert_called_with(
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{"code": "code", "message": "message"}, {"code": "code", "message": "message"},
) )
@ -321,13 +375,13 @@ async def test_error_event_type(
async def test_send_tts_not_called( async def test_send_tts_not_called(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test the UDP server with a v1 device does not call _send_tts.""" """Test the UDP server with a v1 device does not call _send_tts."""
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
) as mock_send_tts: ) as mock_send_tts:
voice_assistant_udp_server_v1._event_callback( voice_assistant_udp_pipeline_v1._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -339,15 +393,35 @@ async def test_send_tts_not_called(
mock_send_tts.assert_not_called() mock_send_tts.assert_not_called()
async def test_send_tts_called( async def test_send_tts_called_udp(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
) -> None: ) -> None:
"""Test the UDP server with a v2 device calls _send_tts.""" """Test the UDP server with a v2 device calls _send_tts."""
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
) as mock_send_tts: ) as mock_send_tts:
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_pipeline_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
)
)
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
async def test_send_tts_called_api(
hass: HomeAssistant,
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None:
"""Test the API pipeline calls _send_tts."""
with patch(
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
) as mock_send_tts:
voice_assistant_api_pipeline._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -361,29 +435,36 @@ async def test_send_tts_called(
async def test_send_tts_not_called_when_empty( async def test_send_tts_not_called_when_empty(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test the UDP server with a v1/v2 device doesn't call _send_tts when the output is empty.""" """Test the pipelines do not call _send_tts when the output is empty."""
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
) as mock_send_tts: ) as mock_send_tts:
voice_assistant_udp_server_v1._event_callback( voice_assistant_udp_pipeline_v1._event_callback(
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
) )
mock_send_tts.assert_not_called() mock_send_tts.assert_not_called()
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_pipeline_v2._event_callback(
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
)
mock_send_tts.assert_not_called()
voice_assistant_api_pipeline._event_callback(
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
) )
mock_send_tts.assert_not_called() mock_send_tts.assert_not_called()
async def test_send_tts( async def test_send_tts_udp(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
test_wav, test_wav,
) -> None: ) -> None:
"""Test the UDP server calls sendto to transmit audio data to device.""" """Test the UDP server calls sendto to transmit audio data to device."""
@ -391,12 +472,12 @@ async def test_send_tts(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", test_wav), return_value=("wav", test_wav),
): ):
voice_assistant_udp_server_v2.started = True voice_assistant_udp_pipeline_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
with patch.object( with patch.object(
voice_assistant_udp_server_v2.transport, "is_closing", return_value=False voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False
): ):
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_pipeline_v2._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -408,16 +489,46 @@ async def test_send_tts(
) )
) )
await voice_assistant_udp_server_v2._tts_done.wait() await voice_assistant_udp_pipeline_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_called() voice_assistant_udp_pipeline_v2.transport.sendto.assert_called()
async def test_send_tts_api(
    hass: HomeAssistant,
    mock_client: APIClient,
    voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
    test_wav,
) -> None:
    """Check a TTS_END event makes the API pipeline call send_voice_assistant_audio."""
    # Event announcing that TTS output is available for the test media.
    tts_end_event = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={
            "tts_output": {
                "media_id": _TEST_MEDIA_ID,
                "url": _TEST_OUTPUT_URL,
            }
        },
    )

    # Serve the WAV fixture instead of resolving a real media source.
    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", test_wav),
    ):
        voice_assistant_api_pipeline.started = True
        voice_assistant_api_pipeline._event_callback(tts_end_event)

        # Wait until the pipeline signals that TTS handling has finished.
        await voice_assistant_api_pipeline._tts_done.wait()

        mock_client.send_voice_assistant_audio.assert_called()
async def test_send_tts_wrong_sample_rate( async def test_send_tts_wrong_sample_rate(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test the UDP server calls sendto to transmit audio data to device.""" """Test that only 16000Hz audio will be streamed."""
with io.BytesIO() as wav_io: with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file: with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050) wav_file.setframerate(22050)
@ -433,10 +544,10 @@ async def test_send_tts_wrong_sample_rate(
), ),
pytest.raises(ValueError), pytest.raises(ValueError),
): ):
voice_assistant_udp_server_v2.started = True voice_assistant_api_pipeline.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback( voice_assistant_api_pipeline._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -445,13 +556,13 @@ async def test_send_tts_wrong_sample_rate(
) )
) )
assert voice_assistant_udp_server_v2._tts_task is not None assert voice_assistant_api_pipeline._tts_task is not None
await voice_assistant_udp_server_v2._tts_task # raises ValueError await voice_assistant_api_pipeline._tts_task # raises ValueError
async def test_send_tts_wrong_format( async def test_send_tts_wrong_format(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test that only WAV audio will be streamed.""" """Test that only WAV audio will be streamed."""
with ( with (
@ -461,10 +572,10 @@ async def test_send_tts_wrong_format(
), ),
pytest.raises(ValueError), pytest.raises(ValueError),
): ):
voice_assistant_udp_server_v2.started = True voice_assistant_api_pipeline.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback( voice_assistant_api_pipeline._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -473,13 +584,13 @@ async def test_send_tts_wrong_format(
) )
) )
assert voice_assistant_udp_server_v2._tts_task is not None assert voice_assistant_api_pipeline._tts_task is not None
await voice_assistant_udp_server_v2._tts_task # raises ValueError await voice_assistant_api_pipeline._tts_task # raises ValueError
async def test_send_tts_not_started( async def test_send_tts_not_started(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
test_wav, test_wav,
) -> None: ) -> None:
"""Test the UDP server does not call sendto when not started.""" """Test the UDP server does not call sendto when not started."""
@ -487,10 +598,10 @@ async def test_send_tts_not_started(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", test_wav), return_value=("wav", test_wav),
): ):
voice_assistant_udp_server_v2.started = False voice_assistant_udp_pipeline_v2.started = False
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_pipeline_v2._event_callback(
PipelineEvent( PipelineEvent(
type=PipelineEventType.TTS_END, type=PipelineEventType.TTS_END,
data={ data={
@ -499,14 +610,41 @@ async def test_send_tts_not_started(
) )
) )
await voice_assistant_udp_server_v2._tts_done.wait() await voice_assistant_udp_pipeline_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_not_called() voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called()
async def test_send_tts_transport_none(
    hass: HomeAssistant,
    voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
    test_wav,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """Check the UDP pipeline logs a message when asked to send TTS with no transport."""
    pipeline = voice_assistant_udp_pipeline_v2
    tts_end_event = PipelineEvent(
        type=PipelineEventType.TTS_END,
        data={"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}},
    )

    # Serve the WAV fixture instead of resolving a real media source.
    with patch(
        "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
        return_value=("wav", test_wav),
    ):
        # Started, but deliberately left without a transport to send through.
        pipeline.started = True
        pipeline.transport = None

        pipeline._event_callback(tts_end_event)

        # Wait until the pipeline signals that TTS handling has finished.
        await pipeline._tts_done.wait()

        assert "No transport to send audio to" in caplog.text
async def test_wake_word( async def test_wake_word(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test that the pipeline is set to start with Wake word.""" """Test that the pipeline is set to start with Wake word."""
@ -520,9 +658,7 @@ async def test_wake_word(
), ),
patch("asyncio.Event.wait"), # TTS wait event patch("asyncio.Event.wait"), # TTS wait event
): ):
voice_assistant_udp_server_v2.transport = Mock() await voice_assistant_api_pipeline.run_pipeline(
await voice_assistant_udp_server_v2.run_pipeline(
device_id="mock-device-id", device_id="mock-device-id",
conversation_id=None, conversation_id=None,
flags=2, flags=2,
@ -531,7 +667,7 @@ async def test_wake_word(
async def test_wake_word_exception( async def test_wake_word_exception(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test that the pipeline is set to start with Wake word.""" """Test that the pipeline is set to start with Wake word."""
@ -542,7 +678,6 @@ async def test_wake_word_exception(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
): ):
voice_assistant_udp_server_v2.transport = Mock()
def handle_event( def handle_event(
event_type: VoiceAssistantEventType, data: dict[str, str] | None event_type: VoiceAssistantEventType, data: dict[str, str] | None
@ -552,9 +687,9 @@ async def test_wake_word_exception(
assert data["code"] == "pipeline-not-found" assert data["code"] == "pipeline-not-found"
assert data["message"] == "Pipeline not found" assert data["message"] == "Pipeline not found"
voice_assistant_udp_server_v2.handle_event = handle_event voice_assistant_api_pipeline.handle_event = handle_event
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_api_pipeline.run_pipeline(
device_id="mock-device-id", device_id="mock-device-id",
conversation_id=None, conversation_id=None,
flags=2, flags=2,
@ -563,7 +698,7 @@ async def test_wake_word_exception(
async def test_wake_word_abort_exception( async def test_wake_word_abort_exception(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
) -> None: ) -> None:
"""Test that the pipeline is set to start with Wake word.""" """Test that the pipeline is set to start with Wake word."""
@ -575,13 +710,9 @@ async def test_wake_word_abort_exception(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
), ),
patch.object( patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event,
voice_assistant_udp_server_v2, "handle_event"
) as mock_handle_event,
): ):
voice_assistant_udp_server_v2.transport = Mock() await voice_assistant_api_pipeline.run_pipeline(
await voice_assistant_udp_server_v2.run_pipeline(
device_id="mock-device-id", device_id="mock-device-id",
conversation_id=None, conversation_id=None,
flags=2, flags=2,