"""Test ESPHome voice assistant server."""

import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import replace
import io
import socket
from unittest.mock import ANY, Mock, patch
import wave

from aioesphomeapi import (
    APIClient,
    EntityInfo,
    EntityState,
    MediaPlayerFormatPurpose,
    MediaPlayerInfo,
    MediaPlayerSupportedFormat,
    UserService,
    VoiceAssistantAnnounceFinished,
    VoiceAssistantAudioSettings,
    VoiceAssistantCommandFlag,
    VoiceAssistantEventType,
    VoiceAssistantFeature,
    VoiceAssistantTimerEventType,
)
import pytest

from homeassistant.components import assist_satellite, tts
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
    AssistSatelliteConfiguration,
    AssistSatelliteEntity,
    AssistSatelliteEntityFeature,
    AssistSatelliteWakeWord,
)

# pylint: disable-next=hass-component-root-import
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.esphome import DOMAIN
from homeassistant.components.esphome.assist_satellite import (
    EsphomeAssistSatellite,
    VoiceAssistantUDPServer,
)
from homeassistant.components.media_source import PlayMedia
from homeassistant.const import STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er, intent as intent_helper
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.entity_component import EntityComponent

from .conftest import MockESPHomeDevice


def get_satellite_entity(
    hass: HomeAssistant, mac_address: str
) -> EsphomeAssistSatellite | None:
    """Get the satellite entity for a device.

    Looks up the entity registered under the unique id
    ``f"{mac_address}-assist_satellite"`` and returns the live entity object
    from the assist_satellite entity component, or ``None`` if the entity is
    not registered or not loaded.
    """
    ent_reg = er.async_get(hass)
    satellite_entity_id = ent_reg.async_get_entity_id(
        Platform.ASSIST_SATELLITE, DOMAIN, f"{mac_address}-assist_satellite"
    )
    if satellite_entity_id is None:
        return None
    assert satellite_entity_id.endswith("_assist_satellite")

    component: EntityComponent[AssistSatelliteEntity] = hass.data[
        assist_satellite.DOMAIN
    ]
    if (entity := component.get_entity(satellite_entity_id)) is not None:
        assert isinstance(entity, EsphomeAssistSatellite)
        return entity

    return None


@pytest.fixture
def mock_wav() -> bytes:
    """Return test WAV audio (16 kHz, 16-bit, mono)."""
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setframerate(16000)
            wav_file.setsampwidth(2)
            wav_file.setnchannels(1)
            wav_file.writeframes(b"test-wav")

        # Read the buffer before the BytesIO context closes it.
        return wav_io.getvalue()


async def test_no_satellite_without_voice_assistant(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that an assist satellite entity is not created if a voice assistant is not present."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={},
    )
    await hass.async_block_till_done()

    # No satellite entity should be created
    assert get_satellite_entity(hass, mock_device.device_info.mac_address) is None


async def test_pipeline_api_audio(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test a complete pipeline run with API audio (over the TCP connection)."""
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.SPEAKER
            | VoiceAssistantFeature.API_AUDIO
        },
    )
    await hass.async_block_till_done()
    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    assert dev is not None

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    # Block TTS streaming until we're ready.
    # This makes it easier to verify the order of pipeline events.
    stream_tts_audio_ready = asyncio.Event()
    original_stream_tts_audio = satellite._stream_tts_audio

    async def _stream_tts_audio(*args, **kwargs):
        await stream_tts_audio_ready.wait()
        await original_stream_tts_audio(*args, **kwargs)

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        # Stand-in for the real pipeline: drains the STT stream, then feeds
        # pipeline events to the satellite and checks what it forwards.
        assert device_id == dev.id

        stt_stream = kwargs["stt_stream"]

        chunks = [chunk async for chunk in stt_stream]

        # Verify test API audio
        assert chunks == [b"test-mic"]

        event_callback = kwargs["event_callback"]

        # Test unknown event type
        event_callback(
            PipelineEvent(
                type="unknown-event",
                data={},
            )
        )

        mock_client.send_voice_assistant_event.assert_not_called()

        # Test error event
        event_callback(
            PipelineEvent(
                type=PipelineEventType.ERROR,
                data={"code": "test-error-code", "message": "test-error-message"},
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {"code": "test-error-code", "message": "test-error-message"},
        )

        # Wake word
        assert satellite.state == AssistSatelliteState.IDLE

        event_callback(
            PipelineEvent(
                type=PipelineEventType.WAKE_WORD_START,
                data={
                    "entity_id": "test-wake-word-entity-id",
                    "metadata": {},
                    "timeout": 0,
                },
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START,
            {},
        )

        # Test no wake word detected
        event_callback(
            PipelineEvent(
                type=PipelineEventType.WAKE_WORD_END, data={"wake_word_output": {}}
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
            {"code": "no_wake_word", "message": "No wake word detected"},
        )

        # Correct wake word detection
        event_callback(
            PipelineEvent(
                type=PipelineEventType.WAKE_WORD_END,
                data={"wake_word_output": {"wake_word_phrase": "test-wake-word"}},
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END,
            {},
        )

        # STT
        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_START,
                data={"engine": "test-stt-engine", "metadata": {}},
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
            {},
        )
        assert satellite.state == AssistSatelliteState.LISTENING

        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_END,
                data={"stt_output": {"text": "test-stt-text"}},
            )
        )
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
            {"text": "test-stt-text"},
        )

        # Intent
        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_START,
                data={
                    "engine": "test-intent-engine",
                    "language": hass.config.language,
                    "intent_input": "test-intent-text",
                    "conversation_id": conversation_id,
                    "device_id": device_id,
                },
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
            {},
        )
        assert satellite.state == AssistSatelliteState.PROCESSING

        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_END,
                data={"intent_output": {"conversation_id": conversation_id}},
            )
        )
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
            {"conversation_id": conversation_id},
        )

        # TTS
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_START,
                data={
                    "engine": "test-tts-engine",
                    "language": hass.config.language,
                    "voice": "test-voice",
                    "tts_input": "test-tts-text",
                },
            )
        )

        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
            {"text": "test-tts-text"},
        )
        assert satellite.state == AssistSatelliteState.RESPONDING

        # Should return mock_wav audio
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={"tts_output": {"url": media_url, "media_id": media_id}},
            )
        )
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
            {"url": media_url},
        )

        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END,
            {},
        )

        # Allow TTS streaming to proceed
        stream_tts_audio_ready.set()

    # Wrap the satellite's completion hooks so the test can await them.
    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("wav", mock_wav)

    tts_finished = asyncio.Event()
    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        original_tts_response_finished()
        tts_finished.set()

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        # Should be cleared at pipeline start
        satellite._audio_queue.put_nowait(b"leftover-data")

        # Should be cancelled at pipeline start
        mock_tts_streaming_task = Mock()
        satellite._tts_streaming_task = mock_tts_streaming_task

        async with asyncio.timeout(1):
            await satellite.handle_pipeline_start(
                conversation_id=conversation_id,
                flags=VoiceAssistantCommandFlag.USE_WAKE_WORD,
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )
            mock_tts_streaming_task.cancel.assert_called_once()
            await satellite.handle_audio(b"test-mic")
            await satellite.handle_pipeline_stop(abort=False)
            await pipeline_finished.wait()

            await tts_finished.wait()

        # Verify TTS streaming events.
        # These are definitely the last two events because we blocked TTS streaming
        # until after RUN_END above.
        assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
            {},
        )
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
            {},
        )

        # Verify TTS WAV audio chunk came through
        mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav")


@pytest.mark.usefixtures("socket_enabled")
async def test_pipeline_udp_audio(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test a complete pipeline run with legacy UDP audio.

    This test is not as comprehensive as test_pipeline_api_audio since we're
    mainly focused on the UDP server.
    """
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.SPEAKER
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    mic_audio_event = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        stt_stream = kwargs["stt_stream"]

        chunks: list[bytes] = []
        async for chunk in stt_stream:
            chunks.append(chunk)
            mic_audio_event.set()

        # Verify test UDP audio
        assert chunks == [b"test-mic"]

        event_callback = kwargs["event_callback"]

        # STT
        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_START,
                data={"engine": "test-stt-engine", "metadata": {}},
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_END,
                data={"stt_output": {"text": "test-stt-text"}},
            )
        )

        # Intent
        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_START,
                data={
                    "engine": "test-intent-engine",
                    "language": hass.config.language,
                    "intent_input": "test-intent-text",
                    "conversation_id": conversation_id,
                    "device_id": device_id,
                },
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_END,
                data={"intent_output": {"conversation_id": conversation_id}},
            )
        )

        # TTS
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_START,
                data={
                    "engine": "test-tts-engine",
                    "language": hass.config.language,
                    "voice": "test-voice",
                    "tts_input": "test-tts-text",
                },
            )
        )

        # Should return mock_wav audio
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={"tts_output": {"url": media_url, "media_id": media_id}},
            )
        )

        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))

    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("wav", mock_wav)

    tts_finished = asyncio.Event()
    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        original_tts_response_finished()
        tts_finished.set()

    class TestProtocol(asyncio.DatagramProtocol):
        """Minimal UDP client that records every datagram it receives."""

        def __init__(self) -> None:
            self.transport = None
            self.data_received: list[bytes] = []

        def connection_made(self, transport):
            self.transport = transport

        def datagram_received(self, data: bytes, addr):
            self.data_received.append(data)

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        async with asyncio.timeout(1):
            port = await satellite.handle_pipeline_start(
                conversation_id=conversation_id,
                flags=VoiceAssistantCommandFlag(0),  # stt
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )
            assert (port is not None) and (port > 0)

            (
                transport,
                protocol,
            ) = await asyncio.get_running_loop().create_datagram_endpoint(
                TestProtocol, remote_addr=("127.0.0.1", port)
            )
            try:
                assert isinstance(protocol, TestProtocol)

                # Send audio over UDP
                transport.sendto(b"test-mic")

                # Wait for audio chunk to be delivered
                await mic_audio_event.wait()

                await satellite.handle_pipeline_stop(abort=False)
                await pipeline_finished.wait()

                await tts_finished.wait()

                # Verify TTS audio (from UDP)
                assert protocol.data_received == [b"test-wav"]
            finally:
                # Release the client-side UDP socket even if an assertion fails;
                # otherwise it stays open for the rest of the test session.
                transport.close()

            # Check that UDP server was stopped
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                sock.setblocking(False)
                sock.bind(("", port))  # will fail if UDP server is still running


async def test_udp_errors() -> None:
    """Test UDP protocol error conditions."""
    audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
    protocol = VoiceAssistantUDPServer(audio_queue)

    protocol.datagram_received(b"test", ("", 0))
    assert audio_queue.qsize() == 1
    assert (await audio_queue.get()) == b"test"

    # None will stop the pipeline
    protocol.error_received(RuntimeError())
    assert audio_queue.qsize() == 1
    assert (await audio_queue.get()) is None

    # No transport
    assert protocol.transport is None
    protocol.send_audio_bytes(b"test")

    # No remote address
    protocol.transport = Mock()
    protocol.remote_addr = None
    protocol.send_audio_bytes(b"test")
    protocol.transport.sendto.assert_not_called()


async def test_pipeline_media_player(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.

    This test is not as comprehensive as test_pipeline_api_audio since we're
    mainly focused on tts_response_finished getting automatically called.
    """
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.API_AUDIO
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        stt_stream = kwargs["stt_stream"]

        async for _chunk in stt_stream:
            break

        event_callback = kwargs["event_callback"]

        # STT
        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_START,
                data={"engine": "test-stt-engine", "metadata": {}},
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_END,
                data={"stt_output": {"text": "test-stt-text"}},
            )
        )

        # Intent
        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_START,
                data={
                    "engine": "test-intent-engine",
                    "language": hass.config.language,
                    "intent_input": "test-intent-text",
                    "conversation_id": conversation_id,
                    "device_id": device_id,
                },
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_END,
                data={"intent_output": {"conversation_id": conversation_id}},
            )
        )

        # TTS
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_START,
                data={
                    "engine": "test-tts-engine",
                    "language": hass.config.language,
                    "voice": "test-voice",
                    "tts_input": "test-tts-text",
                },
            )
        )

        # Should return mock_wav audio
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={"tts_output": {"url": media_url, "media_id": media_id}},
            )
        )

        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))

    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("wav", mock_wav)

    tts_finished = asyncio.Event()
    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        original_tts_response_finished()
        tts_finished.set()

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        async with asyncio.timeout(1):
            await satellite.handle_pipeline_start(
                conversation_id=conversation_id,
                flags=VoiceAssistantCommandFlag(0),  # stt
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )

            await satellite.handle_pipeline_stop(abort=False)
            await pipeline_finished.wait()

            assert satellite.state == AssistSatelliteState.RESPONDING

            # Will trigger tts_response_finished
            await mock_device.mock_voice_assistant_handle_announcement_finished(
                VoiceAssistantAnnounceFinished(success=True)
            )
            await tts_finished.wait()

            assert satellite.state == AssistSatelliteState.IDLE


async def test_timer_events(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that injecting timer events results in the correct api client calls."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.TIMERS
        },
    )
    await hass.async_block_till_done()
    assert mock_device.entry.unique_id is not None
    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    assert dev is not None

    total_seconds = (1 * 60 * 60) + (2 * 60) + 3
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_START_TIMER,
        {
            "name": {"value": "test timer"},
            "hours": {"value": 1},
            "minutes": {"value": 2},
            "seconds": {"value": 3},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
        ANY,
        "test timer",
        total_seconds,
        total_seconds,
        True,
    )

    # Increase timer beyond original time and check total_seconds has increased
    mock_client.send_voice_assistant_timer_event.reset_mock()

    total_seconds += 5 * 60
    await intent_helper.async_handle(
        hass,
        "test",
        intent_helper.INTENT_INCREASE_TIMER,
        {
            "name": {"value": "test timer"},
            "minutes": {"value": 5},
        },
        device_id=dev.id,
    )

    mock_client.send_voice_assistant_timer_event.assert_called_with(
        VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
        ANY,
        "test timer",
        total_seconds,
        ANY,
        True,
    )


async def test_unknown_timer_event(
    hass: HomeAssistant,
    device_registry: dr.DeviceRegistry,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that unknown (new) timer event types do not result in api calls."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.TIMERS
        },
    )
    await hass.async_block_till_done()
    assert mock_device.entry.unique_id is not None
    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    assert dev is not None

    # Simulate a timer event type that has no ESPHome equivalent.
    with patch(
        "homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass",
        side_effect=KeyError,
    ):
        await intent_helper.async_handle(
            hass,
            "test",
            intent_helper.INTENT_START_TIMER,
            {
                "name": {"value": "test timer"},
                "hours": {"value": 1},
                "minutes": {"value": 2},
                "seconds": {"value": 3},
            },
            device_id=dev.id,
        )

        mock_client.send_voice_assistant_timer_event.assert_not_called()


async def test_streaming_tts_errors(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test error conditions for _stream_tts_audio function."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    # Should not stream if not running
    satellite._is_running = False
    await satellite._stream_tts_audio("test-media-id")
    mock_client.send_voice_assistant_audio.assert_not_called()
    satellite._is_running = True

    # Should only stream WAV
    async def get_mp3(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        return ("mp3", b"")

    with patch(
        "homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
    ):
        await satellite._stream_tts_audio("test-media-id")
        mock_client.send_voice_assistant_audio.assert_not_called()

    # Needs to be the correct sample rate, etc.
    async def get_bad_wav(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        with io.BytesIO() as wav_io:
            with wave.open(wav_io, "wb") as wav_file:
                wav_file.setframerate(48000)
                wav_file.setsampwidth(2)
                wav_file.setnchannels(1)
                wav_file.writeframes(b"test-wav")

            return ("wav", wav_io.getvalue())

    with patch(
        "homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
    ):
        await satellite._stream_tts_audio("test-media-id")
        mock_client.send_voice_assistant_audio.assert_not_called()

    # Check that TTS_STREAM_* events still get sent after cancel
    media_fetched = asyncio.Event()

    async def get_slow_wav(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        media_fetched.set()
        await asyncio.sleep(1)
        return ("wav", mock_wav)

    mock_client.send_voice_assistant_event.reset_mock()

    with patch(
        "homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
    ):
        task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
        async with asyncio.timeout(1):
            # Wait for media to be fetched
            await media_fetched.wait()

        # Cancel task
        task.cancel()
        await task

        # No audio should have gone out
        mock_client.send_voice_assistant_audio.assert_not_called()
        assert len(mock_client.send_voice_assistant_event.call_args_list) == 2

        # The TTS_STREAM_* events should have gone out
        assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
            {},
        )
        assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
            VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
            {},
        )


async def test_tts_format_from_media_player(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that the text-to-speech format is pulled from the first media player."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[
            MediaPlayerInfo(
                object_id="mymedia_player",
                key=1,
                name="my media_player",
                unique_id="my_media_player",
                supports_pause=True,
                supported_formats=[
                    MediaPlayerSupportedFormat(
                        format="flac",
                        sample_rate=48000,
                        num_channels=2,
                        purpose=MediaPlayerFormatPurpose.DEFAULT,
                        sample_bytes=2,
                    ),
                    # This is the format that should be used for tts
                    MediaPlayerSupportedFormat(
                        format="mp3",
                        sample_rate=22050,
                        num_channels=1,
                        purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
                        sample_bytes=2,
                    ),
                ],
            )
        ],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
    ) as mock_pipeline_from_audio_stream:
        await satellite.handle_pipeline_start(
            conversation_id="",
            flags=VoiceAssistantCommandFlag(0),
            audio_settings=VoiceAssistantAudioSettings(),
            wake_word_phrase=None,
        )

        mock_pipeline_from_audio_stream.assert_called_once()
        kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs

        # Should be ANNOUNCEMENT format from media player
        assert kwargs.get("tts_audio_output") == {
            tts.ATTR_PREFERRED_FORMAT: "mp3",
            tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
            tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
        }


async def test_tts_minimal_format_from_media_player(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test text-to-speech format when media player only specifies the codec."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[
            MediaPlayerInfo(
                object_id="mymedia_player",
                key=1,
                name="my media_player",
                unique_id="my_media_player",
                supports_pause=True,
                supported_formats=[
                    MediaPlayerSupportedFormat(
                        format="flac",
                        sample_rate=48000,
                        num_channels=2,
                        purpose=MediaPlayerFormatPurpose.DEFAULT,
                        sample_bytes=2,
                    ),
                    # This is the format that should be used for tts
                    MediaPlayerSupportedFormat(
                        format="mp3",
                        sample_rate=0,  # source rate
                        num_channels=0,  # source channels
                        purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
                        sample_bytes=0,  # source width
                    ),
                ],
            )
        ],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    with patch(
        "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
    ) as mock_pipeline_from_audio_stream:
        await satellite.handle_pipeline_start(
            conversation_id="",
            flags=VoiceAssistantCommandFlag(0),
            audio_settings=VoiceAssistantAudioSettings(),
            wake_word_phrase=None,
        )

        mock_pipeline_from_audio_stream.assert_called_once()
        kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs

        # Should be ANNOUNCEMENT format from media player
        assert kwargs.get("tts_audio_output") == {
            tts.ATTR_PREFERRED_FORMAT: "mp3",
        }


async def test_announce_supported_features(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that the announce supported feature is set by flags."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    assert not (satellite.supported_features & AssistSatelliteEntityFeature.ANNOUNCE)


async def test_announce_message(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test announcement with message."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.SPEAKER
            | VoiceAssistantFeature.API_AUDIO
            | VoiceAssistantFeature.ANNOUNCE
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    done = asyncio.Event()

    async def send_voice_assistant_announcement_await_response(
        media_id: str, timeout: float, text: str
    ):
        assert satellite.state == AssistSatelliteState.RESPONDING
        assert media_id == "https://www.home-assistant.io/resolved.mp3"
        assert text == "test-text"

        done.set()

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
            return_value="media-source://bla",
        ),
        patch(
            "homeassistant.components.media_source.async_resolve_media",
            return_value=PlayMedia(
                url="https://www.home-assistant.io/resolved.mp3",
                mime_type="audio/mp3",
            ),
        ),
        patch.object(
            mock_client,
            "send_voice_assistant_announcement_await_response",
            new=send_voice_assistant_announcement_await_response,
        ),
    ):
        async with asyncio.timeout(1):
            await hass.services.async_call(
                assist_satellite.DOMAIN,
                "announce",
                {"entity_id": satellite.entity_id, "message": "test-text"},
                blocking=True,
            )
            await done.wait()
            assert satellite.state == AssistSatelliteState.IDLE


async def test_announce_media_id(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    device_registry: dr.DeviceRegistry,
) -> None:
    """Test announcement with media id."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[
            MediaPlayerInfo(
                object_id="mymedia_player",
                key=1,
                name="my media_player",
                unique_id="my_media_player",
                supports_pause=True,
                supported_formats=[
                    MediaPlayerSupportedFormat(
                        format="flac",
                        sample_rate=48000,
                        num_channels=2,
                        purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
                        sample_bytes=2,
                    ),
                ],
            )
        ],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.SPEAKER
            | VoiceAssistantFeature.API_AUDIO
            | VoiceAssistantFeature.ANNOUNCE
        },
    )
    await hass.async_block_till_done()

    dev = device_registry.async_get_device(
        connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
    )
    assert dev is not None

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    done = asyncio.Event()

    async def send_voice_assistant_announcement_await_response(
        media_id: str, timeout: float, text: str
    ):
        assert satellite.state == AssistSatelliteState.RESPONDING
        assert media_id == "https://www.home-assistant.io/proxied.flac"

        done.set()

    with (
        patch.object(
            mock_client,
            "send_voice_assistant_announcement_await_response",
            new=send_voice_assistant_announcement_await_response,
        ),
        patch(
            "homeassistant.components.esphome.assist_satellite.async_create_proxy_url",
            return_value="https://www.home-assistant.io/proxied.flac",
        ) as mock_async_create_proxy_url,
    ):
        async with asyncio.timeout(1):
            await hass.services.async_call(
                assist_satellite.DOMAIN,
                "announce",
                {
                    "entity_id": satellite.entity_id,
                    "media_id": "https://www.home-assistant.io/resolved.mp3",
                },
                blocking=True,
            )
            await done.wait()
            assert satellite.state == AssistSatelliteState.IDLE

        # The proxy URL should be requested with the ANNOUNCEMENT format
        # advertised by the media player.
        mock_async_create_proxy_url.assert_called_once_with(
            hass,
            dev.id,
            "https://www.home-assistant.io/resolved.mp3",
            media_format="flac",
            rate=48000,
            channels=2,
            width=2,
        )


async def test_satellite_unloaded_on_disconnect(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test that the assist satellite platform is unloaded on disconnect."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    state = hass.states.get(satellite.entity_id)
    assert state is not None
    assert state.state != STATE_UNAVAILABLE

    # Device will be unavailable after disconnect
    await mock_device.mock_disconnect(True)

    state = hass.states.get(satellite.entity_id)
    assert state is not None
    assert state.state == STATE_UNAVAILABLE


async def test_pipeline_abort(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test aborting a pipeline (no further processing)."""
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.API_AUDIO
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    chunks: list[bytes] = []
    chunk_received = asyncio.Event()
    pipeline_aborted = asyncio.Event()

    async def async_pipeline_from_audio_stream(*args, **kwargs):
        stt_stream = kwargs["stt_stream"]

        try:
            async for chunk in stt_stream:
                chunks.append(chunk)
                chunk_received.set()
        except asyncio.CancelledError:
            # Aborting cancels the pipeline task
            pipeline_aborted.set()
            raise

    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
    ):
        async with asyncio.timeout(1):
            await satellite.handle_pipeline_start(
                conversation_id="",
                flags=VoiceAssistantCommandFlag(0),  # stt
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )

            await satellite.handle_audio(b"before-abort")
            await chunk_received.wait()

            # Abort the pipeline, no further processing
            await satellite.handle_pipeline_stop(abort=True)
            await pipeline_aborted.wait()

            # This chunk should not make it into the STT stream
            await satellite.handle_audio(b"after-abort")
            await pipeline_finished.wait()

            # Only first chunk
            assert chunks == [b"before-abort"]


async def test_get_set_configuration(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
) -> None:
    """Test getting and setting the satellite configuration."""
    expected_config = AssistSatelliteConfiguration(
        available_wake_words=[
            AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
            AssistSatelliteWakeWord("5678", "hey jarvis", ["en"]),
        ],
        active_wake_words=["1234"],
        max_active_wake_words=1,
    )
    mock_client.get_voice_assistant_configuration.return_value = expected_config

    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.ANNOUNCE
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    # HA should have been updated
    actual_config = satellite.async_get_configuration()
    assert actual_config == expected_config

    updated_config = replace(actual_config, active_wake_words=["5678"])
mock_client.get_voice_assistant_configuration.return_value = updated_config # Change active wake words await satellite.async_set_configuration(updated_config) # Set config method should be called mock_client.set_voice_assistant_configuration.assert_called_once_with( active_wake_words=["5678"] ) # Device should have been updated assert satellite.async_get_configuration() == updated_config