diff --git a/homeassistant/components/esphome/binary_sensor.py b/homeassistant/components/esphome/binary_sensor.py index ac0676d8d1e..05ddfc2c43f 100644 --- a/homeassistant/components/esphome/binary_sensor.py +++ b/homeassistant/components/esphome/binary_sensor.py @@ -33,7 +33,9 @@ async def async_setup_entry( entry_data = DomainData.get(hass).get_entry_data(entry) assert entry_data.device_info is not None - if entry_data.device_info.voice_assistant_version: + if entry_data.device_info.voice_assistant_feature_flags_compat( + entry_data.api_version + ): async_add_entities([EsphomeAssistInProgressBinarySensor(entry_data)]) diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index 877c099deee..005963db872 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -257,7 +257,9 @@ class RuntimeEntryData: if async_get_dashboard(hass): needed_platforms.add(Platform.UPDATE) - if self.device_info and self.device_info.voice_assistant_version: + if self.device_info and self.device_info.voice_assistant_feature_flags_compat( + self.api_version + ): needed_platforms.add(Platform.BINARY_SENSOR) needed_platforms.add(Platform.SELECT) diff --git a/homeassistant/components/esphome/manager.py b/homeassistant/components/esphome/manager.py index 3813d22ce97..ef56f3a2164 100644 --- a/homeassistant/components/esphome/manager.py +++ b/homeassistant/components/esphome/manager.py @@ -21,6 +21,7 @@ from aioesphomeapi import ( UserService, UserServiceArgType, VoiceAssistantAudioSettings, + VoiceAssistantFeature, ) from awesomeversion import AwesomeVersion import voluptuous as vol @@ -72,7 +73,11 @@ from .domain_data import DomainData # Import config flow so that it's added to the registry from .entry_data import RuntimeEntryData -from .voice_assistant import VoiceAssistantUDPServer +from .voice_assistant import ( + VoiceAssistantAPIPipeline, + VoiceAssistantPipeline, + VoiceAssistantUDPPipeline, +) _LOGGER = logging.getLogger(__name__) @@ -143,7 +148,7 @@ class ESPHomeManager: "cli", "device_id", "domain_data", - "voice_assistant_udp_server", + "voice_assistant_pipeline", "reconnect_logic", "zeroconf_instance", "entry_data", @@ -168,7 +173,7 @@ class ESPHomeManager: self.cli = cli self.device_id: str | None = None self.domain_data = domain_data - self.voice_assistant_udp_server: VoiceAssistantUDPServer | None = None + self.voice_assistant_pipeline: VoiceAssistantPipeline | None = None self.reconnect_logic: ReconnectLogic | None = None self.zeroconf_instance = zeroconf_instance self.entry_data = entry_data @@ -327,9 +332,10 @@ class ESPHomeManager: def _handle_pipeline_finished(self) -> None: self.entry_data.async_set_assist_pipeline_state(False) - if self.voice_assistant_udp_server is not None: - self.voice_assistant_udp_server.close() - self.voice_assistant_udp_server = None + if self.voice_assistant_pipeline is not None: + if isinstance(self.voice_assistant_pipeline, VoiceAssistantUDPPipeline): + self.voice_assistant_pipeline.close() + self.voice_assistant_pipeline = None async def _handle_pipeline_start( self, @@ -339,38 +345,60 @@ class ESPHomeManager: wake_word_phrase: str | None, ) -> int | None: """Start a voice assistant pipeline.""" - if self.voice_assistant_udp_server is not None: + if self.voice_assistant_pipeline is not None: _LOGGER.warning("Voice assistant UDP server was not stopped") - self.voice_assistant_udp_server.stop() - self.voice_assistant_udp_server = None + self.voice_assistant_pipeline.stop() + self.voice_assistant_pipeline = None hass = self.hass - self.voice_assistant_udp_server = VoiceAssistantUDPServer( - hass, - self.entry_data, - self.cli.send_voice_assistant_event, - self._handle_pipeline_finished, - ) - port = await self.voice_assistant_udp_server.start_server() + assert self.entry_data.device_info is not None + if ( + self.entry_data.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + & VoiceAssistantFeature.API_AUDIO + ): + self.voice_assistant_pipeline = VoiceAssistantAPIPipeline( + hass, + self.entry_data, + self.cli.send_voice_assistant_event, + self._handle_pipeline_finished, + self.cli, + ) + port = 0 + else: + self.voice_assistant_pipeline = VoiceAssistantUDPPipeline( + hass, + self.entry_data, + self.cli.send_voice_assistant_event, + self._handle_pipeline_finished, + ) + port = await self.voice_assistant_pipeline.start_server() assert self.device_id is not None, "Device ID must be set" hass.async_create_background_task( - self.voice_assistant_udp_server.run_pipeline( + self.voice_assistant_pipeline.run_pipeline( device_id=self.device_id, conversation_id=conversation_id or None, flags=flags, audio_settings=audio_settings, wake_word_phrase=wake_word_phrase, ), - "esphome.voice_assistant_udp_server.run_pipeline", + "esphome.voice_assistant_pipeline.run_pipeline", ) return port async def _handle_pipeline_stop(self) -> None: """Stop a voice assistant pipeline.""" - if self.voice_assistant_udp_server is not None: - self.voice_assistant_udp_server.stop() + if self.voice_assistant_pipeline is not None: + self.voice_assistant_pipeline.stop() + + async def _handle_audio(self, data: bytes) -> None: + if self.voice_assistant_pipeline is None: + return + assert isinstance(self.voice_assistant_pipeline, VoiceAssistantAPIPipeline) + self.voice_assistant_pipeline.receive_audio_bytes(data) async def on_connect(self) -> None: """Subscribe to states and list entities on successful API login.""" @@ -472,13 +500,23 @@ class ESPHomeManager: ) ) - if device_info.voice_assistant_version: - entry_data.disconnect_callbacks.add( - cli.subscribe_voice_assistant( - self._handle_pipeline_start, - self._handle_pipeline_stop, + flags = device_info.voice_assistant_feature_flags_compat(api_version) + if flags: + if flags & VoiceAssistantFeature.API_AUDIO: + entry_data.disconnect_callbacks.add( + cli.subscribe_voice_assistant( + handle_start=self._handle_pipeline_start, + handle_stop=self._handle_pipeline_stop, + handle_audio=self._handle_audio, + ) + ) + else: + entry_data.disconnect_callbacks.add( + cli.subscribe_voice_assistant( + handle_start=self._handle_pipeline_start, + handle_stop=self._handle_pipeline_stop, + ) ) - ) cli.subscribe_states(entry_data.async_update_state) cli.subscribe_service_calls(self.async_on_service_call) diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json index f1a5333c403..4d5636a6f26 100644 --- a/homeassistant/components/esphome/manifest.json +++ b/homeassistant/components/esphome/manifest.json @@ -15,7 +15,7 @@ "iot_class": "local_push", "loggers": ["aioesphomeapi", "noiseprotocol", "bleak_esphome"], "requirements": [ - "aioesphomeapi==23.2.0", + "aioesphomeapi==24.0.0", "esphome-dashboard-api==1.2.3", "bleak-esphome==1.0.0" ], diff --git a/homeassistant/components/esphome/select.py b/homeassistant/components/esphome/select.py index 07a9d70e558..612ffc4bcc6 100644 --- a/homeassistant/components/esphome/select.py +++ b/homeassistant/components/esphome/select.py @@ -42,7 +42,9 @@ async def async_setup_entry( entry_data = DomainData.get(hass).get_entry_data(entry) assert entry_data.device_info is not None - if entry_data.device_info.voice_assistant_version: + if entry_data.device_info.voice_assistant_feature_flags_compat( + entry_data.api_version + ): async_add_entities( [ EsphomeAssistPipelineSelect(hass, entry_data), diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index f856cc27179..f9f753389ed 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -11,9 +11,11 @@ from typing import cast import wave from aioesphomeapi import ( + APIClient, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, VoiceAssistantEventType, + VoiceAssistantFeature, ) from homeassistant.components import stt, tts @@ -64,13 +66,11 @@ _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ ) -class VoiceAssistantUDPServer(asyncio.DatagramProtocol): - """Receive UDP packets and forward them to the voice assistant.""" +class VoiceAssistantPipeline: + """Base abstract pipeline class.""" started = False stop_requested = False - transport: asyncio.DatagramTransport | None = None - remote_addr: tuple[str, int] | None = None def __init__( self, @@ -79,12 +79,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], handle_finished: Callable[[], None], ) -> None: - """Initialize UDP receiver.""" + """Initialize the pipeline.""" self.context = Context() self.hass = hass - - assert entry_data.device_info is not None self.entry_data = entry_data + assert entry_data.device_info is not None self.device_info = entry_data.device_info self.queue: asyncio.Queue[bytes] = asyncio.Queue() @@ -95,69 +94,9 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): @property def is_running(self) -> bool: - """True if the UDP server is started and hasn't been asked to stop.""" + """True if the pipeline is started and hasn't been asked to stop.""" return self.started and (not self.stop_requested) - async def start_server(self) -> int: - """Start accepting connections.""" - - def accept_connection() -> VoiceAssistantUDPServer: - """Accept connection.""" - if self.started: - raise RuntimeError("Can only start once") - if self.stop_requested: - raise RuntimeError("No longer accepting connections") - - self.started = True - return self - - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.setblocking(False) - - sock.bind(("", UDP_PORT)) - - await asyncio.get_running_loop().create_datagram_endpoint( - accept_connection, sock=sock - ) - - return cast(int, sock.getsockname()[1]) - - @callback - def connection_made(self, transport: asyncio.BaseTransport) -> None: - """Store transport for later use.""" - self.transport = cast(asyncio.DatagramTransport, transport) - - @callback - def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: - """Handle incoming UDP packet.""" - if not self.is_running: - return - if self.remote_addr is None: - self.remote_addr = addr - self.queue.put_nowait(data) - - def error_received(self, exc: Exception) -> None: - """Handle when a send or receive operation raises an OSError. - - (Other than BlockingIOError or InterruptedError.) - """ - _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) - self.handle_finished() - - @callback - def stop(self) -> None: - """Stop the receiver.""" - self.queue.put_nowait(b"") - self.close() - - def close(self) -> None: - """Close the receiver.""" - self.started = False - self.stop_requested = True - - if self.transport is not None: - self.transport.close() - async def _iterate_packets(self) -> AsyncIterable[bytes]: """Iterate over incoming packets.""" while data := await self.queue.get(): @@ -198,7 +137,12 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): url = async_process_play_media_url(self.hass, path) data_to_send = {"url": url} - if self.device_info.voice_assistant_version >= 2: + if ( + self.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + & VoiceAssistantFeature.SPEAKER + ): media_id = tts_output["media_id"] self._tts_task = self.hass.async_create_background_task( self._send_tts(media_id), "esphome_voice_assistant_tts" @@ -243,9 +187,15 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): if audio_settings is None or audio_settings.volume_multiplier == 0: audio_settings = VoiceAssistantAudioSettings() - tts_audio_output = ( - "wav" if self.device_info.voice_assistant_version >= 2 else "mp3" - ) + if ( + self.device_info.voice_assistant_feature_flags_compat( + self.entry_data.api_version + ) + & VoiceAssistantFeature.SPEAKER + ): + tts_audio_output = "wav" + else: + tts_audio_output = "mp3" _LOGGER.debug("Starting pipeline") if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD: @@ -315,7 +265,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}) try: - if (not self.is_running) or (self.transport is None): + if not self.is_running: return extension, data = await tts.async_get_media_source_audio( @@ -358,16 +308,133 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): samples_in_chunk = len(chunk) // bytes_per_sample samples_left -= samples_in_chunk - self.transport.sendto(chunk, self.remote_addr) + self.send_audio_bytes(chunk) await asyncio.sleep( samples_in_chunk / stt.AudioSampleRates.SAMPLERATE_16000 * 0.9 ) sample_offset += samples_in_chunk - finally: self.handle_event( VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} ) self._tts_task = None self._tts_done.set() + + def send_audio_bytes(self, data: bytes) -> None: + """Send bytes to the device.""" + raise NotImplementedError + + def stop(self) -> None: + """Stop the pipeline.""" + self.queue.put_nowait(b"") + + +class VoiceAssistantUDPPipeline(asyncio.DatagramProtocol, VoiceAssistantPipeline): + """Receive UDP packets and forward them to the voice assistant.""" + + transport: asyncio.DatagramTransport | None = None + remote_addr: tuple[str, int] | None = None + + async def start_server(self) -> int: + """Start accepting connections.""" + + def accept_connection() -> VoiceAssistantUDPPipeline: + """Accept connection.""" + if self.started: + raise RuntimeError("Can only start once") + if self.stop_requested: + raise RuntimeError("No longer accepting connections") + + self.started = True + return self + + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.setblocking(False) + + sock.bind(("", UDP_PORT)) + + await asyncio.get_running_loop().create_datagram_endpoint( + accept_connection, sock=sock + ) + + return cast(int, sock.getsockname()[1]) + + @callback + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Store transport for later use.""" + self.transport = cast(asyncio.DatagramTransport, transport) + + @callback + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + """Handle incoming UDP packet.""" + if not self.is_running: + return + if self.remote_addr is None: + self.remote_addr = addr + self.queue.put_nowait(data) + + def error_received(self, exc: Exception) -> None: + """Handle when a send or receive operation raises an OSError. + + (Other than BlockingIOError or InterruptedError.) + """ + _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) + self.handle_finished() + + @callback + def stop(self) -> None: + """Stop the receiver.""" + super().stop() + self.close() + + def close(self) -> None: + """Close the receiver.""" + self.started = False + self.stop_requested = True + + if self.transport is not None: + self.transport.close() + + def send_audio_bytes(self, data: bytes) -> None: + """Send bytes to the device via UDP.""" + if self.transport is None: + _LOGGER.error("No transport to send audio to") + return + self.transport.sendto(data, self.remote_addr) + + +class VoiceAssistantAPIPipeline(VoiceAssistantPipeline): + """Send audio to the voice assistant via the API.""" + + def __init__( + self, + hass: HomeAssistant, + entry_data: RuntimeEntryData, + handle_event: Callable[[VoiceAssistantEventType, dict[str, str] | None], None], + handle_finished: Callable[[], None], + api_client: APIClient, + ) -> None: + """Initialize the pipeline.""" + super().__init__(hass, entry_data, handle_event, handle_finished) + self.api_client = api_client + self.started = True + + def send_audio_bytes(self, data: bytes) -> None: + """Send bytes to the device via the API.""" + self.api_client.send_voice_assistant_audio(data) + + @callback + def receive_audio_bytes(self, data: bytes) -> None: + """Receive audio bytes from the device.""" + if not self.is_running: + return + self.queue.put_nowait(data) + + @callback + def stop(self) -> None: + """Stop the pipeline.""" + super().stop() + + self.started = False + self.stop_requested = True diff --git a/requirements_all.txt b/requirements_all.txt index 446d69da244..6c705cb9a18 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -242,7 +242,7 @@ aioelectricitymaps==0.4.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==23.2.0 +aioesphomeapi==24.0.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 29292ae897c..e7b27bfa01a 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -221,7 +221,7 @@ aioelectricitymaps==0.4.0 aioemonitor==1.0.5 # homeassistant.components.esphome -aioesphomeapi==23.2.0 +aioesphomeapi==24.0.0 # homeassistant.components.flo aioflo==2021.11.0 diff --git a/tests/components/esphome/conftest.py b/tests/components/esphome/conftest.py index cb6655f710c..e23f020991d 100644 --- a/tests/components/esphome/conftest.py +++ b/tests/components/esphome/conftest.py @@ -18,6 +18,7 @@ from aioesphomeapi import ( HomeassistantServiceCall, ReconnectLogic, UserService, + VoiceAssistantFeature, ) import pytest from zeroconf import Zeroconf @@ -354,10 +355,16 @@ async def mock_voice_assistant_entry( ): """Set up an ESPHome entry with voice assistant.""" - async def _mock_voice_assistant_entry(version: int) -> MockConfigEntry: + async def _mock_voice_assistant_entry( + voice_assistant_feature_flags: VoiceAssistantFeature, + ) -> MockConfigEntry: return ( await _mock_generic_device_entry( - hass, mock_client, {"voice_assistant_version": version}, ([], []), [] + hass, + mock_client, + {"voice_assistant_feature_flags": voice_assistant_feature_flags}, + ([], []), + [], ) ).entry @@ -367,13 +374,28 @@ async def mock_voice_assistant_entry( @pytest.fixture async def mock_voice_assistant_v1_entry(mock_voice_assistant_entry) -> MockConfigEntry: """Set up an ESPHome entry with voice assistant.""" - return await mock_voice_assistant_entry(version=1) + return await mock_voice_assistant_entry( + voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT + ) @pytest.fixture async def mock_voice_assistant_v2_entry(mock_voice_assistant_entry) -> MockConfigEntry: """Set up an ESPHome entry with voice assistant.""" - return await mock_voice_assistant_entry(version=2) + return await mock_voice_assistant_entry( + voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.SPEAKER + ) + + +@pytest.fixture +async def mock_voice_assistant_api_entry(mock_voice_assistant_entry) -> MockConfigEntry: + """Set up an ESPHome entry with voice assistant.""" + return await mock_voice_assistant_entry( + voice_assistant_feature_flags=VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.SPEAKER + | VoiceAssistantFeature.API_AUDIO + ) @pytest.fixture diff --git a/tests/components/esphome/test_diagnostics.py b/tests/components/esphome/test_diagnostics.py index 0f2b18218ff..1cf4f77875f 100644 --- a/tests/components/esphome/test_diagnostics.py +++ b/tests/components/esphome/test_diagnostics.py @@ -94,7 +94,8 @@ async def test_diagnostics_with_bluetooth( "project_version": "", "suggested_area": "", "uses_password": False, - "voice_assistant_version": 0, + "legacy_voice_assistant_version": 0, + "voice_assistant_feature_flags": 0, "webserver_port": 0, }, "services": [], diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index 9882419ed5a..e67d833656e 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -6,7 +6,7 @@ import socket from unittest.mock import Mock, patch import wave -from aioesphomeapi import VoiceAssistantEventType +from aioesphomeapi import APIClient, VoiceAssistantEventType import pytest from homeassistant.components.assist_pipeline import ( @@ -19,7 +19,10 @@ from homeassistant.components.assist_pipeline.error import ( WakeWordDetectionError, ) from homeassistant.components.esphome import DomainData -from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer +from homeassistant.components.esphome.voice_assistant import ( + VoiceAssistantAPIPipeline, + VoiceAssistantUDPPipeline, +) from homeassistant.core import HomeAssistant _TEST_INPUT_TEXT = "This is an input test" @@ -31,43 +34,54 @@ _ONE_SECOND = 16000 * 2 # 16Khz 16-bit @pytest.fixture -def voice_assistant_udp_server( +def voice_assistant_udp_pipeline( hass: HomeAssistant, -) -> VoiceAssistantUDPServer: - """Return the UDP server factory.""" +) -> VoiceAssistantUDPPipeline: + """Return the UDP pipeline factory.""" def _voice_assistant_udp_server(entry): entry_data = DomainData.get(hass).get_entry_data(entry) - server: VoiceAssistantUDPServer = None + server: VoiceAssistantUDPPipeline = None def handle_finished(): nonlocal server assert server is not None server.close() - server = VoiceAssistantUDPServer(hass, entry_data, Mock(), handle_finished) + server = VoiceAssistantUDPPipeline(hass, entry_data, Mock(), handle_finished) return server # noqa: RET504 return _voice_assistant_udp_server @pytest.fixture -def voice_assistant_udp_server_v1( - voice_assistant_udp_server, - mock_voice_assistant_v1_entry, -) -> VoiceAssistantUDPServer: - """Return the UDP server.""" - return voice_assistant_udp_server(entry=mock_voice_assistant_v1_entry) +def voice_assistant_api_pipeline( + hass: HomeAssistant, + mock_client, + mock_voice_assistant_api_entry, +) -> VoiceAssistantAPIPipeline: + """Return the API Pipeline factory.""" + entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_api_entry) + return VoiceAssistantAPIPipeline(hass, entry_data, Mock(), Mock(), mock_client) @pytest.fixture -def voice_assistant_udp_server_v2( - voice_assistant_udp_server, +def voice_assistant_udp_pipeline_v1( + voice_assistant_udp_pipeline, + mock_voice_assistant_v1_entry, +) -> VoiceAssistantUDPPipeline: + """Return the UDP pipeline.""" + return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v1_entry) + + +@pytest.fixture +def voice_assistant_udp_pipeline_v2( + voice_assistant_udp_pipeline, mock_voice_assistant_v2_entry, -) -> VoiceAssistantUDPServer: - """Return the UDP server.""" - return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry) +) -> VoiceAssistantUDPPipeline: + """Return the UDP pipeline.""" + return voice_assistant_udp_pipeline(entry=mock_voice_assistant_v2_entry) @pytest.fixture @@ -85,7 +99,7 @@ def test_wav() -> bytes: async def test_pipeline_events( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test that the pipeline function is called.""" @@ -145,15 +159,15 @@ async def test_pipeline_events( elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: assert data is None - voice_assistant_udp_server_v1.handle_event = handle_event + voice_assistant_udp_pipeline_v1.handle_event = handle_event with patch( "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ): - voice_assistant_udp_server_v1.transport = Mock() + voice_assistant_udp_pipeline_v1.transport = Mock() - await voice_assistant_udp_server_v1.run_pipeline( + await voice_assistant_udp_pipeline_v1.run_pipeline( device_id="mock-device-id", conversation_id=None ) @@ -162,7 +176,7 @@ async def test_udp_server( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test the UDP server runs and queues incoming data.""" port_to_use = unused_udp_port_factory() @@ -170,93 +184,133 @@ async def test_udp_server( with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use ): - port = await voice_assistant_udp_server_v1.start_server() + port = await voice_assistant_udp_pipeline_v1.start_server() assert port == port_to_use sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - assert voice_assistant_udp_server_v1.queue.qsize() == 0 + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0 sock.sendto(b"test", ("127.0.0.1", port)) # Give the socket some time to send/receive the data async with asyncio.timeout(1): - while voice_assistant_udp_server_v1.queue.qsize() == 0: + while voice_assistant_udp_pipeline_v1.queue.qsize() == 0: await asyncio.sleep(0.1) - assert voice_assistant_udp_server_v1.queue.qsize() == 1 + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 - voice_assistant_udp_server_v1.stop() - voice_assistant_udp_server_v1.close() + voice_assistant_udp_pipeline_v1.stop() + voice_assistant_udp_pipeline_v1.close() - assert voice_assistant_udp_server_v1.transport.is_closing() + assert voice_assistant_udp_pipeline_v1.transport.is_closing() async def test_udp_server_queue( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test the UDP server queues incoming data.""" - voice_assistant_udp_server_v1.started = True + voice_assistant_udp_pipeline_v1.started = True - assert voice_assistant_udp_server_v1.queue.qsize() == 0 + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 0 - voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) - assert voice_assistant_udp_server_v1.queue.qsize() == 1 + voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 - voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) - assert voice_assistant_udp_server_v1.queue.qsize() == 2 + voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 2 - async for data in voice_assistant_udp_server_v1._iterate_packets(): + async for data in voice_assistant_udp_pipeline_v1._iterate_packets(): assert data == bytes(1024) break - assert voice_assistant_udp_server_v1.queue.qsize() == 1 # One message removed + assert voice_assistant_udp_pipeline_v1.queue.qsize() == 1 # One message removed - voice_assistant_udp_server_v1.stop() + voice_assistant_udp_pipeline_v1.stop() assert ( - voice_assistant_udp_server_v1.queue.qsize() == 2 + voice_assistant_udp_pipeline_v1.queue.qsize() == 2 ) # An empty message added by stop - voice_assistant_udp_server_v1.datagram_received(bytes(1024), ("localhost", 0)) + voice_assistant_udp_pipeline_v1.datagram_received(bytes(1024), ("localhost", 0)) assert ( - voice_assistant_udp_server_v1.queue.qsize() == 2 + voice_assistant_udp_pipeline_v1.queue.qsize() == 2 ) # No new messages added after stop - voice_assistant_udp_server_v1.close() + voice_assistant_udp_pipeline_v1.close() # Stopping the UDP server should cause _iterate_packets to break out # immediately without yielding any data. has_data = False - async for _data in voice_assistant_udp_server_v1._iterate_packets(): + async for _data in voice_assistant_udp_pipeline_v1._iterate_packets(): has_data = True assert not has_data, "Server was stopped" +async def test_api_pipeline_queue( + hass: HomeAssistant, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, +) -> None: + """Test the API pipeline queues incoming data.""" + + voice_assistant_api_pipeline.started = True + + assert voice_assistant_api_pipeline.queue.qsize() == 0 + + voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) + assert voice_assistant_api_pipeline.queue.qsize() == 1 + + voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) + assert voice_assistant_api_pipeline.queue.qsize() == 2 + + async for data in voice_assistant_api_pipeline._iterate_packets(): + assert data == bytes(1024) + break + assert voice_assistant_api_pipeline.queue.qsize() == 1 # One message removed + + voice_assistant_api_pipeline.stop() + assert ( + voice_assistant_api_pipeline.queue.qsize() == 2 + ) # An empty message added by stop + + voice_assistant_api_pipeline.receive_audio_bytes(bytes(1024)) + assert ( + voice_assistant_api_pipeline.queue.qsize() == 2 + ) # No new messages added after stop + + # Stopping the API Pipeline should cause _iterate_packets to break out + # immediately without yielding any data. + has_data = False + async for _data in voice_assistant_api_pipeline._iterate_packets(): + has_data = True + + assert not has_data, "Pipeline was stopped" + + async def test_error_calls_handle_finished( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test that the handle_finished callback is called when an error occurs.""" - voice_assistant_udp_server_v1.handle_finished = Mock() + voice_assistant_udp_pipeline_v1.handle_finished = Mock() - voice_assistant_udp_server_v1.error_received(Exception()) + voice_assistant_udp_pipeline_v1.error_received(Exception()) - voice_assistant_udp_server_v1.handle_finished.assert_called() + voice_assistant_udp_pipeline_v1.handle_finished.assert_called() async def test_udp_server_multiple( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test that the UDP server raises an error if started twice.""" with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=unused_udp_port_factory(), ): - await voice_assistant_udp_server_v1.start_server() + await voice_assistant_udp_pipeline_v1.start_server() with ( patch( @@ -265,17 +319,17 @@ async def test_udp_server_multiple( ), pytest.raises(RuntimeError), ): - await voice_assistant_udp_server_v1.start_server() + await voice_assistant_udp_pipeline_v1.start_server() async def test_udp_server_after_stopped( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test that the UDP server raises an error if started after stopped.""" - voice_assistant_udp_server_v1.close() + voice_assistant_udp_pipeline_v1.close() with ( patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", @@ -283,37 +337,37 @@ async def test_udp_server_after_stopped( ), pytest.raises(RuntimeError), ): - await voice_assistant_udp_server_v1.start_server() + await voice_assistant_udp_pipeline_v1.start_server() async def test_unknown_event_type( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: - """Test the UDP server does not call handle_event for unknown events.""" - voice_assistant_udp_server_v1._event_callback( + """Test the API pipeline does not call handle_event for unknown events.""" + voice_assistant_api_pipeline._event_callback( PipelineEvent( type="unknown-event", data={}, ) ) - assert not voice_assistant_udp_server_v1.handle_event.called + assert not voice_assistant_api_pipeline.handle_event.called async def test_error_event_type( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: - """Test the UDP server calls event handler with error.""" - voice_assistant_udp_server_v1._event_callback( + """Test the API pipeline calls event handler with error.""" + voice_assistant_api_pipeline._event_callback( PipelineEvent( type=PipelineEventType.ERROR, data={"code": "code", "message": "message"}, ) ) - voice_assistant_udp_server_v1.handle_event.assert_called_with( + voice_assistant_api_pipeline.handle_event.assert_called_with( VoiceAssistantEventType.VOICE_ASSISTANT_ERROR, {"code": "code", "message": "message"}, ) @@ -321,13 +375,13 @@ async def test_error_event_type( async def test_send_tts_not_called( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, ) -> None: """Test the UDP server with a v1 device does not call _send_tts.""" with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" + "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" ) as mock_send_tts: - voice_assistant_udp_server_v1._event_callback( + voice_assistant_udp_pipeline_v1._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -339,15 +393,35 @@ async def test_send_tts_not_called( mock_send_tts.assert_not_called() -async def test_send_tts_called( +async def test_send_tts_called_udp( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, ) -> None: """Test the UDP server with a v2 device calls _send_tts.""" with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" + "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" ) as mock_send_tts: - voice_assistant_udp_server_v2._event_callback( + voice_assistant_udp_pipeline_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + + mock_send_tts.assert_called_with(_TEST_MEDIA_ID) + + +async def test_send_tts_called_api( + hass: HomeAssistant, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, +) -> None: + """Test the API pipeline calls _send_tts.""" + with patch( + "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" + ) as mock_send_tts: + voice_assistant_api_pipeline._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -361,29 +435,36 @@ async def test_send_tts_called( async def test_send_tts_not_called_when_empty( hass: HomeAssistant, - voice_assistant_udp_server_v1: VoiceAssistantUDPServer, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v1: VoiceAssistantUDPPipeline, + voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: - """Test the UDP server with a v1/v2 device doesn't call _send_tts when the output is empty.""" + """Test the pipelines do not call _send_tts when the output is empty.""" with patch( - "homeassistant.components.esphome.voice_assistant.VoiceAssistantUDPServer._send_tts" + "homeassistant.components.esphome.voice_assistant.VoiceAssistantPipeline._send_tts" ) as mock_send_tts: - voice_assistant_udp_server_v1._event_callback( + voice_assistant_udp_pipeline_v1._event_callback( PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) ) mock_send_tts.assert_not_called() - voice_assistant_udp_server_v2._event_callback( + voice_assistant_udp_pipeline_v2._event_callback( + PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) + ) + + mock_send_tts.assert_not_called() + + voice_assistant_api_pipeline._event_callback( PipelineEvent(type=PipelineEventType.TTS_END, data={"tts_output": {}}) ) mock_send_tts.assert_not_called() -async def test_send_tts( +async def test_send_tts_udp( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, test_wav, ) -> None: """Test the UDP server calls sendto to transmit audio data to device.""" @@ -391,12 +472,12 @@ async def test_send_tts( "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", return_value=("wav", test_wav), ): - voice_assistant_udp_server_v2.started = True - voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + voice_assistant_udp_pipeline_v2.started = True + voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport) with patch.object( - voice_assistant_udp_server_v2.transport, "is_closing", return_value=False + voice_assistant_udp_pipeline_v2.transport, "is_closing", return_value=False ): - voice_assistant_udp_server_v2._event_callback( + voice_assistant_udp_pipeline_v2._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -408,16 +489,46 @@ async def test_send_tts( ) ) - await voice_assistant_udp_server_v2._tts_done.wait() + await voice_assistant_udp_pipeline_v2._tts_done.wait() - voice_assistant_udp_server_v2.transport.sendto.assert_called() + voice_assistant_udp_pipeline_v2.transport.sendto.assert_called() + + +async def test_send_tts_api( + hass: HomeAssistant, + mock_client: APIClient, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, + test_wav, +) -> None: + """Test the API pipeline calls cli.send_voice_assistant_audio to transmit audio data to device.""" + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("wav", test_wav), + ): + voice_assistant_api_pipeline.started = True + + voice_assistant_api_pipeline._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": { + "media_id": _TEST_MEDIA_ID, + "url": _TEST_OUTPUT_URL, + } + }, + ) + ) + + await voice_assistant_api_pipeline._tts_done.wait() + + mock_client.send_voice_assistant_audio.assert_called() async def test_send_tts_wrong_sample_rate( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: - """Test the UDP server calls sendto to transmit audio data to device.""" + """Test that only 16000Hz audio will be streamed.""" with io.BytesIO() as wav_io: with wave.open(wav_io, "wb") as wav_file: wav_file.setframerate(22050) @@ -433,10 +544,10 @@ async def test_send_tts_wrong_sample_rate( ), pytest.raises(ValueError), ): - voice_assistant_udp_server_v2.started = True - voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + voice_assistant_api_pipeline.started = True + voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport) - voice_assistant_udp_server_v2._event_callback( + voice_assistant_api_pipeline._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -445,13 +556,13 @@ async def test_send_tts_wrong_sample_rate( ) ) - assert voice_assistant_udp_server_v2._tts_task is not None - await voice_assistant_udp_server_v2._tts_task # raises ValueError + assert voice_assistant_api_pipeline._tts_task is not None + await voice_assistant_api_pipeline._tts_task # raises ValueError async def test_send_tts_wrong_format( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: """Test that only WAV audio will be streamed.""" with ( @@ -461,10 +572,10 @@ async def test_send_tts_wrong_format( ), pytest.raises(ValueError), ): - voice_assistant_udp_server_v2.started = True - voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + voice_assistant_api_pipeline.started = True + voice_assistant_api_pipeline.transport = Mock(spec=asyncio.DatagramTransport) - voice_assistant_udp_server_v2._event_callback( + voice_assistant_api_pipeline._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -473,13 +584,13 @@ async def test_send_tts_wrong_format( ) ) - assert voice_assistant_udp_server_v2._tts_task is not None - await voice_assistant_udp_server_v2._tts_task # raises ValueError + assert voice_assistant_api_pipeline._tts_task is not None + await voice_assistant_api_pipeline._tts_task # raises ValueError async def test_send_tts_not_started( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, test_wav, ) -> None: """Test the UDP server does not call sendto when not started.""" @@ -487,10 +598,10 @@ async def test_send_tts_not_started( "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", return_value=("wav", test_wav), ): - voice_assistant_udp_server_v2.started = False - voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) + voice_assistant_udp_pipeline_v2.started = False + voice_assistant_udp_pipeline_v2.transport = Mock(spec=asyncio.DatagramTransport) - voice_assistant_udp_server_v2._event_callback( + voice_assistant_udp_pipeline_v2._event_callback( PipelineEvent( type=PipelineEventType.TTS_END, data={ @@ -499,14 +610,41 @@ async def test_send_tts_not_started( ) ) - await voice_assistant_udp_server_v2._tts_done.wait() + await voice_assistant_udp_pipeline_v2._tts_done.wait() - voice_assistant_udp_server_v2.transport.sendto.assert_not_called() + voice_assistant_udp_pipeline_v2.transport.sendto.assert_not_called() + + +async def test_send_tts_transport_none( + hass: HomeAssistant, + voice_assistant_udp_pipeline_v2: VoiceAssistantUDPPipeline, + test_wav, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test the UDP server does not call sendto when transport is None.""" + with patch( + "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", + return_value=("wav", test_wav), + ): + voice_assistant_udp_pipeline_v2.started = True + voice_assistant_udp_pipeline_v2.transport = None + + voice_assistant_udp_pipeline_v2._event_callback( + PipelineEvent( + type=PipelineEventType.TTS_END, + data={ + "tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} + }, + ) + ) + await voice_assistant_udp_pipeline_v2._tts_done.wait() + + assert "No transport to send audio to" in caplog.text async def test_wake_word( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: """Test that the pipeline is set to start with Wake word.""" @@ -520,9 +658,7 @@ async def test_wake_word( ), patch("asyncio.Event.wait"), # TTS wait event ): - voice_assistant_udp_server_v2.transport = Mock() - - await voice_assistant_udp_server_v2.run_pipeline( + await voice_assistant_api_pipeline.run_pipeline( device_id="mock-device-id", conversation_id=None, flags=2, @@ -531,7 +667,7 @@ async def test_wake_word( async def test_wake_word_exception( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: """Test that the pipeline is set to start with Wake word.""" @@ -542,7 +678,6 @@ async def test_wake_word_exception( "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ): - voice_assistant_udp_server_v2.transport = Mock() def handle_event( event_type: VoiceAssistantEventType, data: dict[str, str] | None @@ -552,9 +687,9 @@ async def test_wake_word_exception( assert data["code"] == "pipeline-not-found" assert data["message"] == "Pipeline not found" - voice_assistant_udp_server_v2.handle_event = handle_event + voice_assistant_api_pipeline.handle_event = handle_event - await voice_assistant_udp_server_v2.run_pipeline( + await voice_assistant_api_pipeline.run_pipeline( device_id="mock-device-id", conversation_id=None, flags=2, @@ -563,7 +698,7 @@ async def test_wake_word_exception( async def test_wake_word_abort_exception( hass: HomeAssistant, - voice_assistant_udp_server_v2: VoiceAssistantUDPServer, + voice_assistant_api_pipeline: VoiceAssistantAPIPipeline, ) -> None: """Test that the pipeline is set to start with Wake word.""" @@ -575,13 +710,9 @@ async def test_wake_word_abort_exception( "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ), - patch.object( - voice_assistant_udp_server_v2, "handle_event" - ) as mock_handle_event, + patch.object(voice_assistant_api_pipeline, "handle_event") as mock_handle_event, ): - voice_assistant_udp_server_v2.transport = Mock() - - await voice_assistant_udp_server_v2.run_pipeline( + await voice_assistant_api_pipeline.run_pipeline( device_id="mock-device-id", conversation_id=None, flags=2,