Send/receive Voice Assistant audio via ESPHome native API (#114800)
* Protobuf audio test * Remove extraneous code * Rework voice assistant pipeline * Move variables * Fix reading flags * Dont directly put to queue * Bump aioesphomeapi to 24.0.0 * Update tests - Add more tests for API pipeline - Convert some udp tests to use api pipeline - Update fixtures for new device info flags * Fix bad merge --------- Co-authored-by: Michael Hansen <mike@rhasspy.org>pull/115322/head
parent
cad4c3c0c2
commit
68384bba67
|
@ -33,7 +33,9 @@ async def async_setup_entry(
|
||||||
|
|
||||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
entry_data = DomainData.get(hass).get_entry_data(entry)
|
||||||
assert entry_data.device_info is not None
|
assert entry_data.device_info is not None
|
||||||
if entry_data.device_info.voice_assistant_version:
|
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
entry_data.api_version
|
||||||
|
):
|
||||||
async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)])
|
async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -257,7 +257,9 @@ class RuntimeEntryData:
|
||||||
if async_get_dashboard(hass):
|
if async_get_dashboard(hass):
|
||||||
needed_platforms.add(Platform.UPDATE)
|
needed_platforms.add(Platform.UPDATE)
|
||||||
|
|
||||||
if self.device_info and self.device_info.voice_assistant_version:
|
if self.device_info and self.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.api_version
|
||||||
|
):
|
||||||
needed_platforms.add(Platform.BINARY_SENSOR)
|
needed_platforms.add(Platform.BINARY_SENSOR)
|
||||||
needed_platforms.add(Platform.SELECT)
|
needed_platforms.add(Platform.SELECT)
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,7 @@ from aioesphomeapi import (
|
||||||
UserService,
|
UserService,
|
||||||
UserServiceArgType,
|
UserServiceArgType,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
|
VoiceAssistantFeature,
|
||||||
)
|
)
|
||||||
from awesomeversion import AwesomeVersion
|
from awesomeversion import AwesomeVersion
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
@ -72,7 +73,11 @@ from .domain_data import DomainData
|
||||||
|
|
||||||
# Import config flow so that it's added to the registry
|
# Import config flow so that it's added to the registry
|
||||||
from .entry_data import RuntimeEntryData
|
from .entry_data import RuntimeEntryData
|
||||||
from .voice_assistant import VoiceAssistantUDPServer
|
from .voice_assistant import (
|
||||||
|
VoiceAssistantAPIPipeline,
|
||||||
|
VoiceAssistantPipeline,
|
||||||
|
VoiceAssistantUDPPipeline,
|
||||||
|
)
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -143,7 +148,7 @@ class ESPHomeManager:
|
||||||
"cli",
|
"cli",
|
||||||
"device_id",
|
"device_id",
|
||||||
"domain_data",
|
"domain_data",
|
||||||
"voice_assistant_udp_server",
|
"voice_assistant_pipeline",
|
||||||
"reconnect_logic",
|
"reconnect_logic",
|
||||||
"zeroconf_instance",
|
"zeroconf_instance",
|
||||||
"entry_data",
|
"entry_data",
|
||||||
|
@ -168,7 +173,7 @@ class ESPHomeManager:
|
||||||
self.cli = cli
|
self.cli = cli
|
||||||
self.device_id: str | None = None
|
self.device_id: str | None = None
|
||||||
self.domain_data = domain_data
|
self.domain_data = domain_data
|
||||||
self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None
|
self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None
|
||||||
self.reconnect_logic: ReconnectLogic | None = None
|
self.reconnect_logic: ReconnectLogic | None = None
|
||||||
self.zeroconf_instance = zeroconf_instance
|
self.zeroconf_instance = zeroconf_instance
|
||||||
self.entry_data = entry_data
|
self.entry_data = entry_data
|
||||||
|
@ -327,9 +332,10 @@ class ESPHomeManager:
|
||||||
def _handle_pipeline_finished(self) -> None:
|
def _handle_pipeline_finished(self) -> None:
|
||||||
self.entry_data.async_set_assist_pipeline_state(False)
|
self.entry_data.async_set_assist_pipeline_state(False)
|
||||||
|
|
||||||
if self.voice_assistant_udp_server is not None:
|
if self.voice_assistant_pipeline is not None:
|
||||||
self.voice_assistant_udp_server.close()
|
if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline):
|
||||||
self.voice_assistant_udp_server = None
|
self.voice_assistant_pipeline.close()
|
||||||
|
self.voice_assistant_pipeline = None
|
||||||
|
|
||||||
async def _handle_pipeline_start(
|
async def _handle_pipeline_start(
|
||||||
self,
|
self,
|
||||||
|
@ -339,38 +345,60 @@ class ESPHomeManager:
|
||||||
wake_word_phrase: str | None,
|
wake_word_phrase: str | None,
|
||||||
) -> int | None:
|
) -> int | None:
|
||||||
"""Start a voice assistant pipeline."""
|
"""Start a voice assistant pipeline."""
|
||||||
if self.voice_assistant_udp_server is not None:
|
if self.voice_assistant_pipeline is not None:
|
||||||
_LOGGER.warning("Voice assistant UDP server was not stopped")
|
_LOGGER.warning("Voice assistant UDP server was not stopped")
|
||||||
self.voice_assistant_udp_server.stop()
|
self.voice_assistant_pipeline.stop()
|
||||||
self.voice_assistant_udp_server = None
|
self.voice_assistant_pipeline = None
|
||||||
|
|
||||||
hass = self.hass
|
hass = self.hass
|
||||||
self.voice_assistant_udp_server = VoiceAssistantUDPServer(
|
assert self.entry_data.device_info is not None
|
||||||
hass,
|
if (
|
||||||
self.entry_data,
|
self.entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
self.cli.send_voice_assistant_event,
|
self.entry_data.api_version
|
||||||
self._handle_pipeline_finished,
|
)
|
||||||
)
|
& VoiceAssistantFeature.API_AUDIO
|
||||||
port = await self.voice_assistant_udp_server.start_server()
|
):
|
||||||
|
self.voice_assistant_pipeline = VoiceAssistantAPIPipeline(
|
||||||
|
hass,
|
||||||
|
self.entry_data,
|
||||||
|
self.cli.send_voice_assistant_event,
|
||||||
|
self._handle_pipeline_finished,
|
||||||
|
self.cli,
|
||||||
|
)
|
||||||
|
port = 0
|
||||||
|
else:
|
||||||
|
self.voice_assistant_pipeline = VoiceAssistantUDPPipeline(
|
||||||
|
hass,
|
||||||
|
self.entry_data,
|
||||||
|
self.cli.send_voice_assistant_event,
|
||||||
|
self._handle_pipeline_finished,
|
||||||
|
)
|
||||||
|
port = await self.voice_assistant_pipeline.start_server()
|
||||||
|
|
||||||
assert self.device_id is not None, "Device ID must be set"
|
assert self.device_id is not None, "Device ID must be set"
|
||||||
hass.async_create_background_task(
|
hass.async_create_background_task(
|
||||||
self.voice_assistant_udp_server.run_pipeline(
|
self.voice_assistant_pipeline.run_pipeline(
|
||||||
device_id=self.device_id,
|
device_id=self.device_id,
|
||||||
conversation_id=conversation_id or None,
|
conversation_id=conversation_id or None,
|
||||||
flags=flags,
|
flags=flags,
|
||||||
audio_settings=audio_settings,
|
audio_settings=audio_settings,
|
||||||
wake_word_phrase=wake_word_phrase,
|
wake_word_phrase=wake_word_phrase,
|
||||||
),
|
),
|
||||||
"esphome.voice_assistant_udp_server.run_pipeline",
|
"esphome.voice_assistant_pipeline.run_pipeline",
|
||||||
)
|
)
|
||||||
|
|
||||||
return port
|
return port
|
||||||
|
|
||||||
async def _handle_pipeline_stop(self) -> None:
|
async def _handle_pipeline_stop(self) -> None:
|
||||||
"""Stop a voice assistant pipeline."""
|
"""Stop a voice assistant pipeline."""
|
||||||
if self.voice_assistant_udp_server is not None:
|
if self.voice_assistant_pipeline is not None:
|
||||||
self.voice_assistant_udp_server.stop()
|
self.voice_assistant_pipeline.stop()
|
||||||
|
|
||||||
|
async def _handle_audio(self, data: bytes) -> None:
|
||||||
|
if self.voice_assistant_pipeline is None:
|
||||||
|
return
|
||||||
|
assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline)
|
||||||
|
self.voice_assistant_pipeline.receive_audio_bytes(data)
|
||||||
|
|
||||||
async def on_connect(self) -> None:
|
async def on_connect(self) -> None:
|
||||||
"""Subscribe to states and list entities on successful API login."""
|
"""Subscribe to states and list entities on successful API login."""
|
||||||
|
@ -472,13 +500,23 @@ class ESPHomeManager:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if device_info.voice_assistant_version:
|
flags = device_info.voice_assistant_feature_flags_compat(api_version)
|
||||||
entry_data.disconnect_callbacks.add(
|
if flags:
|
||||||
cli.subscribe_voice_assistant(
|
if flags & VoiceAssistantFeature.API_AUDIO:
|
||||||
self._handle_pipeline_start,
|
entry_data.disconnect_callbacks.add(
|
||||||
self._handle_pipeline_stop,
|
cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self._handle_pipeline_start,
|
||||||
|
handle_stop=self._handle_pipeline_stop,
|
||||||
|
handle_audio=self._handle_audio,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
entry_data.disconnect_callbacks.add(
|
||||||
|
cli.subscribe_voice_assistant(
|
||||||
|
handle_start=self._handle_pipeline_start,
|
||||||
|
handle_stop=self._handle_pipeline_stop,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
|
||||||
cli.subscribe_states(entry_data.async_update_state)
|
cli.subscribe_states(entry_data.async_update_state)
|
||||||
cli.subscribe_service_calls(self.async_on_service_call)
|
cli.subscribe_service_calls(self.async_on_service_call)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"],
|
"loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"],
|
||||||
"requirements": [
|
"requirements": [
|
||||||
"aioesphomeapi==23.2.0",
|
"aioesphomeapi==24.0.0",
|
||||||
"esphome-dashboard-api==1.2.3",
|
"esphome-dashboard-api==1.2.3",
|
||||||
"bleak-esphome==1.0.0"
|
"bleak-esphome==1.0.0"
|
||||||
],
|
],
|
||||||
|
|
|
@ -42,7 +42,9 @@ async def async_setup_entry(
|
||||||
|
|
||||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
entry_data = DomainData.get(hass).get_entry_data(entry)
|
||||||
assert entry_data.device_info is not None
|
assert entry_data.device_info is not None
|
||||||
if entry_data.device_info.voice_assistant_version:
|
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
entry_data.api_version
|
||||||
|
):
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
EsphomeAssistPipelineSelect(hass, entry_data),
|
EsphomeAssistPipelineSelect(hass, entry_data),
|
||||||
|
|
|
@ -11,9 +11,11 @@ from typing import cast
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
from aioesphomeapi import (
|
from aioesphomeapi import (
|
||||||
|
APIClient,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantCommandFlag,
|
VoiceAssistantCommandFlag,
|
||||||
VoiceAssistantEventType,
|
VoiceAssistantEventType,
|
||||||
|
VoiceAssistantFeature,
|
||||||
)
|
)
|
||||||
|
|
||||||
from homeassistant.components import stt, tts
|
from homeassistant.components import stt, tts
|
||||||
|
@ -64,13 +66,11 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
class VoiceAssistantPipeline:
|
||||||
"""Receive UDP packets and forward them to the voice assistant."""
|
"""Base abstract pipeline class."""
|
||||||
|
|
||||||
started = False
|
started = False
|
||||||
stop_requested = False
|
stop_requested = False
|
||||||
transport: asyncio.DatagramTransport | None = None
|
|
||||||
remote_addr: tuple[str, int] | None = None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -79,12 +79,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||||
handle_finished: Callable[[], None],
|
handle_finished: Callable[[], None],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize UDP receiver."""
|
"""Initialize the pipeline."""
|
||||||
self.context = Context()
|
self.context = Context()
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
|
|
||||||
assert entry_data.device_info is not None
|
|
||||||
self.entry_data = entry_data
|
self.entry_data = entry_data
|
||||||
|
assert entry_data.device_info is not None
|
||||||
self.device_info = entry_data.device_info
|
self.device_info = entry_data.device_info
|
||||||
|
|
||||||
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
|
self.queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||||
|
@ -95,69 +94,9 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_running(self) -> bool:
|
def is_running(self) -> bool:
|
||||||
"""True if the UDP server is started and hasn't been asked to stop."""
|
"""True if the pipeline is started and hasn't been asked to stop."""
|
||||||
return self.started and (not self.stop_requested)
|
return self.started and (not self.stop_requested)
|
||||||
|
|
||||||
async def start_server(self) -> int:
|
|
||||||
"""Start accepting connections."""
|
|
||||||
|
|
||||||
def accept_connection() -> VoiceAssistantUDPServer:
|
|
||||||
"""Accept connection."""
|
|
||||||
if self.started:
|
|
||||||
raise RuntimeError("Can only start once")
|
|
||||||
if self.stop_requested:
|
|
||||||
raise RuntimeError("No longer accepting connections")
|
|
||||||
|
|
||||||
self.started = True
|
|
||||||
return self
|
|
||||||
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
|
||||||
sock.setblocking(False)
|
|
||||||
|
|
||||||
sock.bind(("", UDP_PORT))
|
|
||||||
|
|
||||||
await asyncio.get_running_loop().create_datagram_endpoint(
|
|
||||||
accept_connection, sock=sock
|
|
||||||
)
|
|
||||||
|
|
||||||
return cast(int, sock.getsockname()[1])
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
|
||||||
"""Store transport for later use."""
|
|
||||||
self.transport = cast(asyncio.DatagramTransport, transport)
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
|
||||||
"""Handle incoming UDP packet."""
|
|
||||||
if not self.is_running:
|
|
||||||
return
|
|
||||||
if self.remote_addr is None:
|
|
||||||
self.remote_addr = addr
|
|
||||||
self.queue.put_nowait(data)
|
|
||||||
|
|
||||||
def error_received(self, exc: Exception) -> None:
|
|
||||||
"""Handle when a send or receive operation raises an OSError.
|
|
||||||
|
|
||||||
(Other than BlockingIOError or InterruptedError.)
|
|
||||||
"""
|
|
||||||
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
|
||||||
self.handle_finished()
|
|
||||||
|
|
||||||
@callback
|
|
||||||
def stop(self) -> None:
|
|
||||||
"""Stop the receiver."""
|
|
||||||
self.queue.put_nowait(b"")
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def close(self) -> None:
|
|
||||||
"""Close the receiver."""
|
|
||||||
self.started = False
|
|
||||||
self.stop_requested = True
|
|
||||||
|
|
||||||
if self.transport is not None:
|
|
||||||
self.transport.close()
|
|
||||||
|
|
||||||
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
||||||
"""Iterate over incoming packets."""
|
"""Iterate over incoming packets."""
|
||||||
while data := await self.queue.get():
|
while data := await self.queue.get():
|
||||||
|
@ -198,7 +137,12 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
url = async_process_play_media_url(self.hass, path)
|
url = async_process_play_media_url(self.hass, path)
|
||||||
data_to_send = {"url": url}
|
data_to_send = {"url": url}
|
||||||
|
|
||||||
if self.device_info.voice_assistant_version >= 2:
|
if (
|
||||||
|
self.device_info.voice_assistant_feature_flags_compat(
|
||||||
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
& VoiceAssistantFeature.SPEAKER
|
||||||
|
):
|
||||||
media_id = tts_output["media_id"]
|
media_id = tts_output["media_id"]
|
||||||
self._tts_task = self.hass.async_create_background_task(
|
self._tts_task = self.hass.async_create_background_task(
|
||||||
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
self._send_tts(media_id), "esphome_voice_assistant_tts"
|
||||||
|
@ -243,9 +187,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
if audio_settings is None or audio_settings.volume_multiplier == 0:
|
if audio_settings is None or audio_settings.volume_multiplier == 0:
|
||||||
audio_settings = VoiceAssistantAudioSettings()
|
audio_settings = VoiceAssistantAudioSettings()
|
||||||
|
|
||||||
tts_audio_output = (
|
if (
|
||||||
"wav" if self.device_info.voice_assistant_version >= 2 else "mp3"
|
self.device_info.voice_assistant_feature_flags_compat(
|
||||||
)
|
self.entry_data.api_version
|
||||||
|
)
|
||||||
|
& VoiceAssistantFeature.SPEAKER
|
||||||
|
):
|
||||||
|
tts_audio_output = "wav"
|
||||||
|
else:
|
||||||
|
tts_audio_output = "mp3"
|
||||||
|
|
||||||
_LOGGER.debug("Starting pipeline")
|
_LOGGER.debug("Starting pipeline")
|
||||||
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD:
|
||||||
|
@ -315,7 +265,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if (not self.is_running) or (self.transport is None):
|
if not self.is_running:
|
||||||
return
|
return
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
extension, data = await tts.async_get_media_source_audio(
|
||||||
|
@ -358,16 +308,133 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||||
samples_in_chunk = len(chunk) // bytes_per_sample
|
samples_in_chunk = len(chunk) // bytes_per_sample
|
||||||
samples_left -= samples_in_chunk
|
samples_left -= samples_in_chunk
|
||||||
|
|
||||||
self.transport.sendto(chunk, self.remote_addr)
|
self.send_audio_bytes(chunk)
|
||||||
await asyncio.sleep(
|
await asyncio.sleep(
|
||||||
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
|
samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9
|
||||||
)
|
)
|
||||||
|
|
||||||
sample_offset += samples_in_chunk
|
sample_offset += samples_in_chunk
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
self.handle_event(
|
self.handle_event(
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {}
|
||||||
)
|
)
|
||||||
self._tts_task = None
|
self._tts_task = None
|
||||||
self._tts_done.set()
|
self._tts_done.set()
|
||||||
|
|
||||||
|
def send_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Send bytes to the device."""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the pipeline."""
|
||||||
|
self.queue.put_nowait(b"")
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline):
|
||||||
|
"""Receive UDP packets and forward them to the voice assistant."""
|
||||||
|
|
||||||
|
transport: asyncio.DatagramTransport | None = None
|
||||||
|
remote_addr: tuple[str, int] | None = None
|
||||||
|
|
||||||
|
async def start_server(self) -> int:
|
||||||
|
"""Start accepting connections."""
|
||||||
|
|
||||||
|
def accept_connection() -> VoiceAssistantUDPPipeline:
|
||||||
|
"""Accept connection."""
|
||||||
|
if self.started:
|
||||||
|
raise RuntimeError("Can only start once")
|
||||||
|
if self.stop_requested:
|
||||||
|
raise RuntimeError("No longer accepting connections")
|
||||||
|
|
||||||
|
self.started = True
|
||||||
|
return self
|
||||||
|
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
sock.setblocking(False)
|
||||||
|
|
||||||
|
sock.bind(("", UDP_PORT))
|
||||||
|
|
||||||
|
await asyncio.get_running_loop().create_datagram_endpoint(
|
||||||
|
accept_connection, sock=sock
|
||||||
|
)
|
||||||
|
|
||||||
|
return cast(int, sock.getsockname()[1])
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
|
"""Store transport for later use."""
|
||||||
|
self.transport = cast(asyncio.DatagramTransport, transport)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||||
|
"""Handle incoming UDP packet."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
if self.remote_addr is None:
|
||||||
|
self.remote_addr = addr
|
||||||
|
self.queue.put_nowait(data)
|
||||||
|
|
||||||
|
def error_received(self, exc: Exception) -> None:
|
||||||
|
"""Handle when a send or receive operation raises an OSError.
|
||||||
|
|
||||||
|
(Other than BlockingIOError or InterruptedError.)
|
||||||
|
"""
|
||||||
|
_LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc)
|
||||||
|
self.handle_finished()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the receiver."""
|
||||||
|
super().stop()
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""Close the receiver."""
|
||||||
|
self.started = False
|
||||||
|
self.stop_requested = True
|
||||||
|
|
||||||
|
if self.transport is not None:
|
||||||
|
self.transport.close()
|
||||||
|
|
||||||
|
def send_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Send bytes to the device via UDP."""
|
||||||
|
if self.transport is None:
|
||||||
|
_LOGGER.error("No transport to send audio to")
|
||||||
|
return
|
||||||
|
self.transport.sendto(data, self.remote_addr)
|
||||||
|
|
||||||
|
|
||||||
|
class VoiceAssistantAPIPipeline(VoiceAssistantPipeline):
|
||||||
|
"""Send audio to the voice assistant via the API."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entry_data: RuntimeEntryData,
|
||||||
|
handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None],
|
||||||
|
handle_finished: Callable[[], None],
|
||||||
|
api_client: APIClient,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize the pipeline."""
|
||||||
|
super().__init__(hass, entry_data, handle_event, handle_finished)
|
||||||
|
self.api_client = api_client
|
||||||
|
self.started = True
|
||||||
|
|
||||||
|
def send_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Send bytes to the device via the API."""
|
||||||
|
self.api_client.send_voice_assistant_audio(data)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def receive_audio_bytes(self, data: bytes) -> None:
|
||||||
|
"""Receive audio bytes from the device."""
|
||||||
|
if not self.is_running:
|
||||||
|
return
|
||||||
|
self.queue.put_nowait(data)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Stop the pipeline."""
|
||||||
|
super().stop()
|
||||||
|
|
||||||
|
self.started = False
|
||||||
|
self.stop_requested = True
|
||||||
|
|
|
@ -242,7 +242,7 @@ aioelectricitymaps==0.4.0
|
||||||
aioemonitor==1.0.5
|
aioemonitor==1.0.5
|
||||||
|
|
||||||
# homeassistant.components.esphome
|
# homeassistant.components.esphome
|
||||||
aioesphomeapi==23.2.0
|
aioesphomeapi==24.0.0
|
||||||
|
|
||||||
# homeassistant.components.flo
|
# homeassistant.components.flo
|
||||||
aioflo==2021.11.0
|
aioflo==2021.11.0
|
||||||
|
|
|
@ -221,7 +221,7 @@ aioelectricitymaps==0.4.0
|
||||||
aioemonitor==1.0.5
|
aioemonitor==1.0.5
|
||||||
|
|
||||||
# homeassistant.components.esphome
|
# homeassistant.components.esphome
|
||||||
aioesphomeapi==23.2.0
|
aioesphomeapi==24.0.0
|
||||||
|
|
||||||
# homeassistant.components.flo
|
# homeassistant.components.flo
|
||||||
aioflo==2021.11.0
|
aioflo==2021.11.0
|
||||||
|
|
|
@ -18,6 +18,7 @@ from aioesphomeapi import (
|
||||||
HomeassistantServiceCall,
|
HomeassistantServiceCall,
|
||||||
ReconnectLogic,
|
ReconnectLogic,
|
||||||
UserService,
|
UserService,
|
||||||
|
VoiceAssistantFeature,
|
||||||
)
|
)
|
||||||
import pytest
|
import pytest
|
||||||
from zeroconf import Zeroconf
|
from zeroconf import Zeroconf
|
||||||
|
@ -354,10 +355,16 @@ async def mock_voice_assistant_entry(
|
||||||
):
|
):
|
||||||
"""Set up an ESPHome entry with voice assistant."""
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
|
|
||||||
async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry:
|
async def _mock_voice_assistant_entry(
|
||||||
|
voice_assistant_feature_flags: VoiceAssistantFeature,
|
||||||
|
) -> MockConfigEntry:
|
||||||
return (
|
return (
|
||||||
await _mock_generic_device_entry(
|
await _mock_generic_device_entry(
|
||||||
hass, mock_client, {"voice_assistant_version": version}, ([], []), []
|
hass,
|
||||||
|
mock_client,
|
||||||
|
{"voice_assistant_feature_flags": voice_assistant_feature_flags},
|
||||||
|
([], []),
|
||||||
|
[],
|
||||||
)
|
)
|
||||||
).entry
|
).entry
|
||||||
|
|
||||||
|
@ -367,13 +374,28 @@ async def mock_voice_assistant_entry(
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||||
"""Set up an ESPHome entry with voice assistant."""
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
return await mock_voice_assistant_entry(version=1)
|
return await mock_voice_assistant_entry(
|
||||||
|
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||||
"""Set up an ESPHome entry with voice assistant."""
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
return await mock_voice_assistant_entry(version=2)
|
return await mock_voice_assistant_entry(
|
||||||
|
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.SPEAKER
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mock_voice_assistant_api_entry(mock_voice_assistant_entry) -> MockConfigEntry:
|
||||||
|
"""Set up an ESPHome entry with voice assistant."""
|
||||||
|
return await mock_voice_assistant_entry(
|
||||||
|
voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT
|
||||||
|
| VoiceAssistantFeature.SPEAKER
|
||||||
|
| VoiceAssistantFeature.API_AUDIO
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -94,7 +94,8 @@ async def test_diagnostics_with_bluetooth(
|
||||||
"project_version": "",
|
"project_version": "",
|
||||||
"suggested_area": "",
|
"suggested_area": "",
|
||||||
"uses_password": False,
|
"uses_password": False,
|
||||||
"voice_assistant_version": 0,
|
"legacy_voice_assistant_version": 0,
|
||||||
|
"voice_assistant_feature_flags": 0,
|
||||||
"webserver_port": 0,
|
"webserver_port": 0,
|
||||||
},
|
},
|
||||||
"services": [],
|
"services": [],
|
||||||
|
|
|
@ -6,7 +6,7 @@ import socket
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
from aioesphomeapi import VoiceAssistantEventType
|
from aioesphomeapi import APIClient, VoiceAssistantEventType
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
|
@ -19,7 +19,10 @@ from homeassistant.components.assist_pipeline.error import (
|
||||||
WakeWordDetectionError,
|
WakeWordDetectionError,
|
||||||
)
|
)
|
||||||
from homeassistant.components.esphome import DomainData
|
from homeassistant.components.esphome import DomainData
|
||||||
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
|
from homeassistant.components.esphome.voice_assistant import (
|
||||||
|
VoiceAssistantAPIPipeline,
|
||||||
|
VoiceAssistantUDPPipeline,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
_TEST_INPUT_TEXT = "This is an input test"
|
_TEST_INPUT_TEXT = "This is an input test"
|
||||||
|
@ -31,43 +34,54 @@ _ONE_SECOND = 16000 * 2 # 16Khz 16-bit
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def voice_assistant_udp_server(
|
def voice_assistant_udp_pipeline(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
) -> VoiceAssistantUDPServer:
|
) -> VoiceAssistantUDPPipeline:
|
||||||
"""Return the UDP server factory."""
|
"""Return the UDP pipeline factory."""
|
||||||
|
|
||||||
def _voice_assistant_udp_server(entry):
|
def _voice_assistant_udp_server(entry):
|
||||||
entry_data = DomainData.get(hass).get_entry_data(entry)
|
entry_data = DomainData.get(hass).get_entry_data(entry)
|
||||||
|
|
||||||
server: VoiceAssistantUDPServer = None
|
server: VoiceAssistantUDPPipeline = None
|
||||||
|
|
||||||
def handle_finished():
|
def handle_finished():
|
||||||
nonlocal server
|
nonlocal server
|
||||||
assert server is not None
|
assert server is not None
|
||||||
server.close()
|
server.close()
|
||||||
|
|
||||||
server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished)
|
server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished)
|
||||||
return server # noqa: RET504
|
return server # noqa: RET504
|
||||||
|
|
||||||
return _voice_assistant_udp_server
|
return _voice_assistant_udp_server
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def voice_assistant_udp_server_v1(
|
def voice_assistant_api_pipeline(
|
||||||
voice_assistant_udp_server,
|
hass: HomeAssistant,
|
||||||
mock_voice_assistant_v1_entry,
|
mock_client,
|
||||||
) -> VoiceAssistantUDPServer:
|
mock_voice_assistant_api_entry,
|
||||||
"""Return the UDP server."""
|
) -> VoiceAssistantAPIPipeline:
|
||||||
return voice_assistant_udp_server(entry=mock_voice_assistant_v1_entry)
|
"""Return the API Pipeline factory."""
|
||||||
|
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry)
|
||||||
|
return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def voice_assistant_udp_server_v2(
|
def voice_assistant_udp_pipeline_v1(
|
||||||
voice_assistant_udp_server,
|
voice_assistant_udp_pipeline,
|
||||||
|
mock_voice_assistant_v1_entry,
|
||||||
|
) -> VoiceAssistantUDPPipeline:
|
||||||
|
"""Return the UDP pipeline."""
|
||||||
|
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def voice_assistant_udp_pipeline_v2(
|
||||||
|
voice_assistant_udp_pipeline,
|
||||||
mock_voice_assistant_v2_entry,
|
mock_voice_assistant_v2_entry,
|
||||||
) -> VoiceAssistantUDPServer:
|
) -> VoiceAssistantUDPPipeline:
|
||||||
"""Return the UDP server."""
|
"""Return the UDP pipeline."""
|
||||||
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry)
|
return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -85,7 +99,7 @@ def test_wav() -> bytes:
|
||||||
|
|
||||||
async def test_pipeline_events(
|
async def test_pipeline_events(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline function is called."""
|
"""Test that the pipeline function is called."""
|
||||||
|
|
||||||
|
@ -145,15 +159,15 @@ async def test_pipeline_events(
|
||||||
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END:
|
||||||
assert data is None
|
assert data is None
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.handle_event = handle_event
|
voice_assistant_udp_pipeline_v1.handle_event = handle_event
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v1.transport = Mock()
|
voice_assistant_udp_pipeline_v1.transport = Mock()
|
||||||
|
|
||||||
await voice_assistant_udp_server_v1.run_pipeline(
|
await voice_assistant_udp_pipeline_v1.run_pipeline(
|
||||||
device_id="mock-device-id", conversation_id=None
|
device_id="mock-device-id", conversation_id=None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -162,7 +176,7 @@ async def test_udp_server(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server runs and queues incoming data."""
|
"""Test the UDP server runs and queues incoming data."""
|
||||||
port_to_use = unused_udp_port_factory()
|
port_to_use = unused_udp_port_factory()
|
||||||
|
@ -170,93 +184,133 @@ async def test_udp_server(
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
||||||
):
|
):
|
||||||
port = await voice_assistant_udp_server_v1.start_server()
|
port = await voice_assistant_udp_pipeline_v1.start_server()
|
||||||
assert port == port_to_use
|
assert port == port_to_use
|
||||||
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 0
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
||||||
sock.sendto(b"test", ("127.0.0.1", port))
|
sock.sendto(b"test", ("127.0.0.1", port))
|
||||||
|
|
||||||
# Give the socket some time to send/receive the data
|
# Give the socket some time to send/receive the data
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
while voice_assistant_udp_server_v1.queue.qsize() == 0:
|
while voice_assistant_udp_pipeline_v1.queue.qsize() == 0:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 1
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.stop()
|
voice_assistant_udp_pipeline_v1.stop()
|
||||||
voice_assistant_udp_server_v1.close()
|
voice_assistant_udp_pipeline_v1.close()
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v1.transport.is_closing()
|
assert voice_assistant_udp_pipeline_v1.transport.is_closing()
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_queue(
|
async def test_udp_server_queue(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server queues incoming data."""
|
"""Test the UDP server queues incoming data."""
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.started = True
|
voice_assistant_udp_pipeline_v1.started = True
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 0
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0))
|
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 1
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0))
|
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 2
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||||
|
|
||||||
async for data in voice_assistant_udp_server_v1._iterate_packets():
|
async for data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
||||||
assert data == bytes(1024)
|
assert data == bytes(1024)
|
||||||
break
|
break
|
||||||
assert voice_assistant_udp_server_v1.queue.qsize() == 1 # One message removed
|
assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.stop()
|
voice_assistant_udp_pipeline_v1.stop()
|
||||||
assert (
|
assert (
|
||||||
voice_assistant_udp_server_v1.queue.qsize() == 2
|
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||||
) # An empty message added by stop
|
) # An empty message added by stop
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0))
|
voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0))
|
||||||
assert (
|
assert (
|
||||||
voice_assistant_udp_server_v1.queue.qsize() == 2
|
voice_assistant_udp_pipeline_v1.queue.qsize() == 2
|
||||||
) # No new messages added after stop
|
) # No new messages added after stop
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.close()
|
voice_assistant_udp_pipeline_v1.close()
|
||||||
|
|
||||||
# Stopping the UDP server should cause _iterate_packets to break out
|
# Stopping the UDP server should cause _iterate_packets to break out
|
||||||
# immediately without yielding any data.
|
# immediately without yielding any data.
|
||||||
has_data = False
|
has_data = False
|
||||||
async for _data in voice_assistant_udp_server_v1._iterate_packets():
|
async for _data in voice_assistant_udp_pipeline_v1._iterate_packets():
|
||||||
has_data = True
|
has_data = True
|
||||||
|
|
||||||
assert not has_data, "Server was stopped"
|
assert not has_data, "Server was stopped"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_api_pipeline_queue(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
|
) -> None:
|
||||||
|
"""Test the API pipeline queues incoming data."""
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline.started = True
|
||||||
|
|
||||||
|
assert voice_assistant_api_pipeline.queue.qsize() == 0
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||||
|
assert voice_assistant_api_pipeline.queue.qsize() == 1
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||||
|
assert voice_assistant_api_pipeline.queue.qsize() == 2
|
||||||
|
|
||||||
|
async for data in voice_assistant_api_pipeline._iterate_packets():
|
||||||
|
assert data == bytes(1024)
|
||||||
|
break
|
||||||
|
assert voice_assistant_api_pipeline.queue.qsize() == 1 # One message removed
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline.stop()
|
||||||
|
assert (
|
||||||
|
voice_assistant_api_pipeline.queue.qsize() == 2
|
||||||
|
) # An empty message added by stop
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024))
|
||||||
|
assert (
|
||||||
|
voice_assistant_api_pipeline.queue.qsize() == 2
|
||||||
|
) # No new messages added after stop
|
||||||
|
|
||||||
|
# Stopping the API Pipeline should cause _iterate_packets to break out
|
||||||
|
# immediately without yielding any data.
|
||||||
|
has_data = False
|
||||||
|
async for _data in voice_assistant_api_pipeline._iterate_packets():
|
||||||
|
has_data = True
|
||||||
|
|
||||||
|
assert not has_data, "Pipeline was stopped"
|
||||||
|
|
||||||
|
|
||||||
async def test_error_calls_handle_finished(
|
async def test_error_calls_handle_finished(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the handle_finished callback is called when an error occurs."""
|
"""Test that the handle_finished callback is called when an error occurs."""
|
||||||
voice_assistant_udp_server_v1.handle_finished = Mock()
|
voice_assistant_udp_pipeline_v1.handle_finished = Mock()
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.error_received(Exception())
|
voice_assistant_udp_pipeline_v1.error_received(Exception())
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.handle_finished.assert_called()
|
voice_assistant_udp_pipeline_v1.handle_finished.assert_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_multiple(
|
async def test_udp_server_multiple(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the UDP server raises an error if started twice."""
|
"""Test that the UDP server raises an error if started twice."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||||
new=unused_udp_port_factory(),
|
new=unused_udp_port_factory(),
|
||||||
):
|
):
|
||||||
await voice_assistant_udp_server_v1.start_server()
|
await voice_assistant_udp_pipeline_v1.start_server()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
|
@ -265,17 +319,17 @@ async def test_udp_server_multiple(
|
||||||
),
|
),
|
||||||
pytest.raises(RuntimeError),
|
pytest.raises(RuntimeError),
|
||||||
):
|
):
|
||||||
await voice_assistant_udp_server_v1.start_server()
|
await voice_assistant_udp_pipeline_v1.start_server()
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_after_stopped(
|
async def test_udp_server_after_stopped(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the UDP server raises an error if started after stopped."""
|
"""Test that the UDP server raises an error if started after stopped."""
|
||||||
voice_assistant_udp_server_v1.close()
|
voice_assistant_udp_pipeline_v1.close()
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||||
|
@ -283,37 +337,37 @@ async def test_udp_server_after_stopped(
|
||||||
),
|
),
|
||||||
pytest.raises(RuntimeError),
|
pytest.raises(RuntimeError),
|
||||||
):
|
):
|
||||||
await voice_assistant_udp_server_v1.start_server()
|
await voice_assistant_udp_pipeline_v1.start_server()
|
||||||
|
|
||||||
|
|
||||||
async def test_unknown_event_type(
|
async def test_unknown_event_type(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server does not call handle_event for unknown events."""
|
"""Test the API pipeline does not call handle_event for unknown events."""
|
||||||
voice_assistant_udp_server_v1._event_callback(
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type="unknown-event",
|
type="unknown-event",
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert not voice_assistant_udp_server_v1.handle_event.called
|
assert not voice_assistant_api_pipeline.handle_event.called
|
||||||
|
|
||||||
|
|
||||||
async def test_error_event_type(
|
async def test_error_event_type(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server calls event handler with error."""
|
"""Test the API pipeline calls event handler with error."""
|
||||||
voice_assistant_udp_server_v1._event_callback(
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.ERROR,
|
type=PipelineEventType.ERROR,
|
||||||
data={"code": "code", "message": "message"},
|
data={"code": "code", "message": "message"},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
voice_assistant_udp_server_v1.handle_event.assert_called_with(
|
voice_assistant_api_pipeline.handle_event.assert_called_with(
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
|
||||||
{"code": "code", "message": "message"},
|
{"code": "code", "message": "message"},
|
||||||
)
|
)
|
||||||
|
@ -321,13 +375,13 @@ async def test_error_event_type(
|
||||||
|
|
||||||
async def test_send_tts_not_called(
|
async def test_send_tts_not_called(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server with a v1 device does not call _send_tts."""
|
"""Test the UDP server with a v1 device does not call _send_tts."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
|
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||||
) as mock_send_tts:
|
) as mock_send_tts:
|
||||||
voice_assistant_udp_server_v1._event_callback(
|
voice_assistant_udp_pipeline_v1._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -339,15 +393,35 @@ async def test_send_tts_not_called(
|
||||||
mock_send_tts.assert_not_called()
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_called(
|
async def test_send_tts_called_udp(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server with a v2 device calls _send_tts."""
|
"""Test the UDP server with a v2 device calls _send_tts."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
|
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||||
) as mock_send_tts:
|
) as mock_send_tts:
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_udp_pipeline_v2._event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_END,
|
||||||
|
data={
|
||||||
|
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_send_tts.assert_called_with(_TEST_MEDIA_ID)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_tts_called_api(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
|
) -> None:
|
||||||
|
"""Test the API pipeline calls _send_tts."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||||
|
) as mock_send_tts:
|
||||||
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -361,29 +435,36 @@ async def test_send_tts_called(
|
||||||
|
|
||||||
async def test_send_tts_not_called_when_empty(
|
async def test_send_tts_not_called_when_empty(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||||
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server with a v1/v2 device doesn't call _send_tts when the output is empty."""
|
"""Test the pipelines do not call _send_tts when the output is empty."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts"
|
"homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts"
|
||||||
) as mock_send_tts:
|
) as mock_send_tts:
|
||||||
voice_assistant_udp_server_v1._event_callback(
|
voice_assistant_udp_pipeline_v1._event_callback(
|
||||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_udp_pipeline_v2._event_callback(
|
||||||
|
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}})
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
mock_send_tts.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts(
|
async def test_send_tts_udp(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||||
test_wav,
|
test_wav,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||||
|
@ -391,12 +472,12 @@ async def test_send_tts(
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||||
return_value=("wav", test_wav),
|
return_value=("wav", test_wav),
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_udp_pipeline_v2.started = True
|
||||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||||
with patch.object(
|
with patch.object(
|
||||||
voice_assistant_udp_server_v2.transport, "is_closing", return_value=False
|
voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_udp_pipeline_v2._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -408,16 +489,46 @@ async def test_send_tts(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2._tts_done.wait()
|
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
voice_assistant_udp_pipeline_v2.transport.sendto.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_tts_api(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_client: APIClient,
|
||||||
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
|
test_wav,
|
||||||
|
) -> None:
|
||||||
|
"""Test the API pipeline calls cli.send_voice_assistant_audio to transmit audio data to device."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||||
|
return_value=("wav", test_wav),
|
||||||
|
):
|
||||||
|
voice_assistant_api_pipeline.started = True
|
||||||
|
|
||||||
|
voice_assistant_api_pipeline._event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_END,
|
||||||
|
data={
|
||||||
|
"tts_output": {
|
||||||
|
"media_id": _TEST_MEDIA_ID,
|
||||||
|
"url": _TEST_OUTPUT_URL,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
await voice_assistant_api_pipeline._tts_done.wait()
|
||||||
|
|
||||||
|
mock_client.send_voice_assistant_audio.assert_called()
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_wrong_sample_rate(
|
async def test_send_tts_wrong_sample_rate(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
"""Test that only 16000Hz audio will be streamed."""
|
||||||
with io.BytesIO() as wav_io:
|
with io.BytesIO() as wav_io:
|
||||||
with wave.open(wav_io, "wb") as wav_file:
|
with wave.open(wav_io, "wb") as wav_file:
|
||||||
wav_file.setframerate(22050)
|
wav_file.setframerate(22050)
|
||||||
|
@ -433,10 +544,10 @@ async def test_send_tts_wrong_sample_rate(
|
||||||
),
|
),
|
||||||
pytest.raises(ValueError),
|
pytest.raises(ValueError),
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_api_pipeline.started = True
|
||||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
||||||
|
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -445,13 +556,13 @@ async def test_send_tts_wrong_sample_rate(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v2._tts_task is not None
|
assert voice_assistant_api_pipeline._tts_task is not None
|
||||||
await voice_assistant_udp_server_v2._tts_task # raises ValueError
|
await voice_assistant_api_pipeline._tts_task # raises ValueError
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_wrong_format(
|
async def test_send_tts_wrong_format(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that only WAV audio will be streamed."""
|
"""Test that only WAV audio will be streamed."""
|
||||||
with (
|
with (
|
||||||
|
@ -461,10 +572,10 @@ async def test_send_tts_wrong_format(
|
||||||
),
|
),
|
||||||
pytest.raises(ValueError),
|
pytest.raises(ValueError),
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = True
|
voice_assistant_api_pipeline.started = True
|
||||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport)
|
||||||
|
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_api_pipeline._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -473,13 +584,13 @@ async def test_send_tts_wrong_format(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
assert voice_assistant_udp_server_v2._tts_task is not None
|
assert voice_assistant_api_pipeline._tts_task is not None
|
||||||
await voice_assistant_udp_server_v2._tts_task # raises ValueError
|
await voice_assistant_api_pipeline._tts_task # raises ValueError
|
||||||
|
|
||||||
|
|
||||||
async def test_send_tts_not_started(
|
async def test_send_tts_not_started(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||||
test_wav,
|
test_wav,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server does not call sendto when not started."""
|
"""Test the UDP server does not call sendto when not started."""
|
||||||
|
@ -487,10 +598,10 @@ async def test_send_tts_not_started(
|
||||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||||
return_value=("wav", test_wav),
|
return_value=("wav", test_wav),
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.started = False
|
voice_assistant_udp_pipeline_v2.started = False
|
||||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||||
|
|
||||||
voice_assistant_udp_server_v2._event_callback(
|
voice_assistant_udp_pipeline_v2._event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={
|
data={
|
||||||
|
@ -499,14 +610,41 @@ async def test_send_tts_not_started(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2._tts_done.wait()
|
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.transport.sendto.assert_not_called()
|
voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_send_tts_transport_none(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline,
|
||||||
|
test_wav,
|
||||||
|
caplog: pytest.LogCaptureFixture,
|
||||||
|
) -> None:
|
||||||
|
"""Test the UDP server does not call sendto when transport is None."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||||
|
return_value=("wav", test_wav),
|
||||||
|
):
|
||||||
|
voice_assistant_udp_pipeline_v2.started = True
|
||||||
|
voice_assistant_udp_pipeline_v2.transport = None
|
||||||
|
|
||||||
|
voice_assistant_udp_pipeline_v2._event_callback(
|
||||||
|
PipelineEvent(
|
||||||
|
type=PipelineEventType.TTS_END,
|
||||||
|
data={
|
||||||
|
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await voice_assistant_udp_pipeline_v2._tts_done.wait()
|
||||||
|
|
||||||
|
assert "No transport to send audio to" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word(
|
async def test_wake_word(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
|
@ -520,9 +658,7 @@ async def test_wake_word(
|
||||||
),
|
),
|
||||||
patch("asyncio.Event.wait"), # TTS wait event
|
patch("asyncio.Event.wait"), # TTS wait event
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.transport = Mock()
|
await voice_assistant_api_pipeline.run_pipeline(
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
device_id="mock-device-id",
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
flags=2,
|
flags=2,
|
||||||
|
@ -531,7 +667,7 @@ async def test_wake_word(
|
||||||
|
|
||||||
async def test_wake_word_exception(
|
async def test_wake_word_exception(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
|
@ -542,7 +678,6 @@ async def test_wake_word_exception(
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.transport = Mock()
|
|
||||||
|
|
||||||
def handle_event(
|
def handle_event(
|
||||||
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
event_type: VoiceAssistantEventType, data: dict[str, str] | None
|
||||||
|
@ -552,9 +687,9 @@ async def test_wake_word_exception(
|
||||||
assert data["code"] == "pipeline-not-found"
|
assert data["code"] == "pipeline-not-found"
|
||||||
assert data["message"] == "Pipeline not found"
|
assert data["message"] == "Pipeline not found"
|
||||||
|
|
||||||
voice_assistant_udp_server_v2.handle_event = handle_event
|
voice_assistant_api_pipeline.handle_event = handle_event
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
await voice_assistant_api_pipeline.run_pipeline(
|
||||||
device_id="mock-device-id",
|
device_id="mock-device-id",
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
flags=2,
|
flags=2,
|
||||||
|
@ -563,7 +698,7 @@ async def test_wake_word_exception(
|
||||||
|
|
||||||
async def test_wake_word_abort_exception(
|
async def test_wake_word_abort_exception(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
voice_assistant_api_pipeline: VoiceAssistantAPIPipeline,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the pipeline is set to start with Wake word."""
|
"""Test that the pipeline is set to start with Wake word."""
|
||||||
|
|
||||||
|
@ -575,13 +710,9 @@ async def test_wake_word_abort_exception(
|
||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch.object(
|
patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event,
|
||||||
voice_assistant_udp_server_v2, "handle_event"
|
|
||||||
) as mock_handle_event,
|
|
||||||
):
|
):
|
||||||
voice_assistant_udp_server_v2.transport = Mock()
|
await voice_assistant_api_pipeline.run_pipeline(
|
||||||
|
|
||||||
await voice_assistant_udp_server_v2.run_pipeline(
|
|
||||||
device_id="mock-device-id",
|
device_id="mock-device-id",
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
flags=2,
|
flags=2,
|
||||||
|
|
Loading…
Reference in New Issue