"""Support for assist satellites in ESPHome.""" from __future__ import annotations import asyncio from collections.abc import AsyncIterable from functools import partial import io from itertools import chain import logging import socket from typing import Any, cast import wave from aioesphomeapi import ( MediaPlayerFormatPurpose, MediaPlayerSupportedFormat, VoiceAssistantAnnounceFinished, VoiceAssistantAudioSettings, VoiceAssistantCommandFlag, VoiceAssistantEventType, VoiceAssistantFeature, VoiceAssistantTimerEventType, ) from homeassistant.components import assist_satellite, tts from homeassistant.components.assist_pipeline import ( PipelineEvent, PipelineEventType, PipelineStage, ) from homeassistant.components.intent import ( TimerEventType, TimerInfo, async_register_timer_handler, ) from homeassistant.components.media_player import async_process_play_media_url from homeassistant.config_entries import ConfigEntry from homeassistant.const import EntityCategory, Platform from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN from .entity import EsphomeAssistEntity from .entry_data import ESPHomeConfigEntry, RuntimeEntryData from .enum_mapper import EsphomeEnumMapper from .ffmpeg_proxy import async_create_proxy_url _LOGGER = logging.getLogger(__name__) _VOICE_ASSISTANT_EVENT_TYPES: EsphomeEnumMapper[ VoiceAssistantEventType, PipelineEventType ] = EsphomeEnumMapper( { VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: PipelineEventType.ERROR, VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START: PipelineEventType.RUN_START, VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END: PipelineEventType.RUN_END, VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: PipelineEventType.STT_START, VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: PipelineEventType.STT_END, VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START: PipelineEventType.INTENT_START, VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: PipelineEventType.INTENT_END, VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: PipelineEventType.TTS_START, VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: PipelineEventType.TTS_END, VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START: PipelineEventType.WAKE_WORD_START, VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: PipelineEventType.WAKE_WORD_END, VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_START: PipelineEventType.STT_VAD_START, VoiceAssistantEventType.VOICE_ASSISTANT_STT_VAD_END: PipelineEventType.STT_VAD_END, } ) _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventType] = ( EsphomeEnumMapper( { VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED: TimerEventType.STARTED, VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED: TimerEventType.UPDATED, VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_CANCELLED: TimerEventType.CANCELLED, VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_FINISHED: TimerEventType.FINISHED, } ) ) _ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60 # 5 minutes _CONFIG_TIMEOUT_SEC = 5 async def async_setup_entry( hass: HomeAssistant, entry: ESPHomeConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up Assist satellite entity.""" entry_data = entry.runtime_data assert entry_data.device_info is not None if entry_data.device_info.voice_assistant_feature_flags_compat( entry_data.api_version ): async_add_entities( [ EsphomeAssistSatellite(entry, entry_data), ] ) class EsphomeAssistSatellite( EsphomeAssistEntity, assist_satellite.AssistSatelliteEntity ): """Satellite running ESPHome.""" entity_description = assist_satellite.AssistSatelliteEntityDescription( key="assist_satellite", translation_key="assist_satellite", entity_category=EntityCategory.CONFIG, ) def __init__( self, config_entry: ConfigEntry, entry_data: RuntimeEntryData, ) -> None: """Initialize satellite.""" super().__init__(entry_data) self.config_entry = config_entry self.entry_data = entry_data self.cli = self.entry_data.client self._is_running: bool = True self._pipeline_task: asyncio.Task | None = None self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue() self._tts_streaming_task: asyncio.Task | None = None self._udp_server: VoiceAssistantUDPServer | None = None # Empty config. Updated when added to HA. self._satellite_config = assist_satellite.AssistSatelliteConfiguration( available_wake_words=[], active_wake_words=[], max_active_wake_words=0 ) @property def pipeline_entity_id(self) -> str | None: """Return the entity ID of the pipeline to use for the next conversation.""" assert self.entry_data.device_info is not None ent_reg = er.async_get(self.hass) return ent_reg.async_get_entity_id( Platform.SELECT, DOMAIN, f"{self.entry_data.device_info.mac_address}-pipeline", ) @property def vad_sensitivity_entity_id(self) -> str | None: """Return the entity ID of the VAD sensitivity to use for the next conversation.""" assert self.entry_data.device_info is not None ent_reg = er.async_get(self.hass) return ent_reg.async_get_entity_id( Platform.SELECT, DOMAIN, f"{self.entry_data.device_info.mac_address}-vad_sensitivity", ) @callback def async_get_configuration( self, ) -> assist_satellite.AssistSatelliteConfiguration: """Get the current satellite configuration.""" return self._satellite_config async def async_set_configuration( self, config: assist_satellite.AssistSatelliteConfiguration ) -> None: """Set the current satellite configuration.""" await self.cli.set_voice_assistant_configuration( active_wake_words=config.active_wake_words ) _LOGGER.debug("Set active wake words: %s", config.active_wake_words) # Ensure configuration is updated await self._update_satellite_config() async def _update_satellite_config(self) -> None: """Get the latest satellite configuration from the device.""" config = await self.cli.get_voice_assistant_configuration(_CONFIG_TIMEOUT_SEC) # Update available/active wake words self._satellite_config.available_wake_words = [ assist_satellite.AssistSatelliteWakeWord( id=model.id, wake_word=model.wake_word, trained_languages=list(model.trained_languages), ) for model in config.available_wake_words ] self._satellite_config.active_wake_words = list(config.active_wake_words) self._satellite_config.max_active_wake_words = config.max_active_wake_words _LOGGER.debug("Received satellite configuration: %s", self._satellite_config) async def async_added_to_hass(self) -> None: """Run when entity about to be added to hass.""" await super().async_added_to_hass() assert self.entry_data.device_info is not None feature_flags = ( self.entry_data.device_info.voice_assistant_feature_flags_compat( self.entry_data.api_version ) ) if feature_flags & VoiceAssistantFeature.API_AUDIO: # TCP audio self.entry_data.disconnect_callbacks.add( self.cli.subscribe_voice_assistant( handle_start=self.handle_pipeline_start, handle_stop=self.handle_pipeline_stop, handle_audio=self.handle_audio, handle_announcement_finished=self.handle_announcement_finished, ) ) else: # UDP audio self.entry_data.disconnect_callbacks.add( self.cli.subscribe_voice_assistant( handle_start=self.handle_pipeline_start, handle_stop=self.handle_pipeline_stop, handle_announcement_finished=self.handle_announcement_finished, ) ) if feature_flags & VoiceAssistantFeature.TIMERS: # Device supports timers assert (self.registry_entry is not None) and ( self.registry_entry.device_id is not None ) self.entry_data.disconnect_callbacks.add( async_register_timer_handler( self.hass, self.registry_entry.device_id, self.handle_timer_event ) ) if feature_flags & VoiceAssistantFeature.ANNOUNCE: # Device supports announcements self._attr_supported_features |= ( assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE ) if not (feature_flags & VoiceAssistantFeature.SPEAKER): # Will use media player for TTS/announcements self._update_tts_format() # Fetch latest config in the background self.config_entry.async_create_background_task( self.hass, self._update_satellite_config(), "esphome_voice_assistant_config" ) async def async_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" await super().async_will_remove_from_hass() self._is_running = False self._stop_pipeline() def on_pipeline_event(self, event: PipelineEvent) -> None: """Handle pipeline events.""" try: event_type = _VOICE_ASSISTANT_EVENT_TYPES.from_hass(event.type) except KeyError: _LOGGER.debug("Received unknown pipeline event type: %s", event.type) return data_to_send: dict[str, Any] = {} if event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_START: self.entry_data.async_set_assist_pipeline_state(True) elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_STT_END: assert event.data is not None data_to_send = {"text": event.data["stt_output"]["text"]} elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END: assert event.data is not None data_to_send = { "conversation_id": event.data["intent_output"]["conversation_id"] or "", } elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START: assert event.data is not None data_to_send = {"text": event.data["tts_input"]} elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END: assert event.data is not None if tts_output := event.data["tts_output"]: path = tts_output["url"] url = async_process_play_media_url(self.hass, path) data_to_send = {"url": url} assert self.entry_data.device_info is not None feature_flags = ( self.entry_data.device_info.voice_assistant_feature_flags_compat( self.entry_data.api_version ) ) if feature_flags & VoiceAssistantFeature.SPEAKER: media_id = tts_output["media_id"] self._tts_streaming_task = ( self.config_entry.async_create_background_task( self.hass, self._stream_tts_audio(media_id), "esphome_voice_assistant_tts", ) ) elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END: assert event.data is not None if not event.data["wake_word_output"]: event_type = VoiceAssistantEventType.VOICE_ASSISTANT_ERROR data_to_send = { "code": "no_wake_word", "message": "No wake word detected", } elif event_type == VoiceAssistantEventType.VOICE_ASSISTANT_ERROR: assert event.data is not None data_to_send = { "code": event.data["code"], "message": event.data["message"], } self.cli.send_voice_assistant_event(event_type, data_to_send) async def async_announce( self, announcement: assist_satellite.AssistSatelliteAnnouncement ) -> None: """Announce media on the satellite. Should block until the announcement is done playing. """ _LOGGER.debug( "Waiting for announcement to finished (message=%s, media_id=%s)", announcement.message, announcement.media_id, ) media_id = announcement.media_id if announcement.media_id_source != "tts": # Route non-TTS media through the proxy format_to_use: MediaPlayerSupportedFormat | None = None for supported_format in chain( *self.entry_data.media_player_formats.values() ): if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT: format_to_use = supported_format break if format_to_use is not None: assert (self.registry_entry is not None) and ( self.registry_entry.device_id is not None ) proxy_url = async_create_proxy_url( self.hass, self.registry_entry.device_id, media_id, media_format=format_to_use.format, rate=format_to_use.sample_rate or None, channels=format_to_use.num_channels or None, width=format_to_use.sample_bytes or None, ) media_id = async_process_play_media_url(self.hass, proxy_url) await self.cli.send_voice_assistant_announcement_await_response( media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message ) async def handle_pipeline_start( self, conversation_id: str, flags: int, audio_settings: VoiceAssistantAudioSettings, wake_word_phrase: str | None, ) -> int | None: """Handle pipeline run request.""" # Clear audio queue while not self._audio_queue.empty(): await self._audio_queue.get() if self._tts_streaming_task is not None: # Cancel current TTS response self._tts_streaming_task.cancel() self._tts_streaming_task = None # API or UDP output audio port: int = 0 assert self.entry_data.device_info is not None feature_flags = ( self.entry_data.device_info.voice_assistant_feature_flags_compat( self.entry_data.api_version ) ) if (feature_flags & VoiceAssistantFeature.SPEAKER) and not ( feature_flags & VoiceAssistantFeature.API_AUDIO ): port = await self._start_udp_server() _LOGGER.debug("Started UDP server on port %s", port) # Device triggered pipeline (wake word, etc.) if flags & VoiceAssistantCommandFlag.USE_WAKE_WORD: start_stage = PipelineStage.WAKE_WORD else: start_stage = PipelineStage.STT end_stage = PipelineStage.TTS if feature_flags & VoiceAssistantFeature.SPEAKER: # Stream WAV audio self._attr_tts_options = { tts.ATTR_PREFERRED_FORMAT: "wav", tts.ATTR_PREFERRED_SAMPLE_RATE: 16000, tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1, tts.ATTR_PREFERRED_SAMPLE_BYTES: 2, } else: # ANNOUNCEMENT format from media player self._update_tts_format() # Run the pipeline _LOGGER.debug("Running pipeline from %s to %s", start_stage, end_stage) self.entry_data.async_set_assist_pipeline_state(True) self._pipeline_task = self.config_entry.async_create_background_task( self.hass, self.async_accept_pipeline_from_satellite( audio_stream=self._wrap_audio_stream(), start_stage=start_stage, end_stage=end_stage, wake_word_phrase=wake_word_phrase, ), "esphome_assist_satellite_pipeline", ) self._pipeline_task.add_done_callback( lambda _future: self.handle_pipeline_finished() ) return port async def handle_audio(self, data: bytes) -> None: """Handle incoming audio chunk from API.""" self._audio_queue.put_nowait(data) async def handle_pipeline_stop(self, abort: bool) -> None: """Handle request for pipeline to stop.""" if abort: self._abort_pipeline() else: self._stop_pipeline() def handle_pipeline_finished(self) -> None: """Handle when pipeline has finished running.""" self.entry_data.async_set_assist_pipeline_state(False) self._stop_udp_server() _LOGGER.debug("Pipeline finished") def handle_timer_event( self, event_type: TimerEventType, timer_info: TimerInfo ) -> None: """Handle timer events.""" try: native_event_type = _TIMER_EVENT_TYPES.from_hass(event_type) except KeyError: _LOGGER.debug("Received unknown timer event type: %s", event_type) return self.cli.send_voice_assistant_timer_event( native_event_type, timer_info.id, timer_info.name, timer_info.created_seconds, timer_info.seconds_left, timer_info.is_active, ) async def handle_announcement_finished( self, announce_finished: VoiceAssistantAnnounceFinished ) -> None: """Handle announcement finished message (also sent for TTS).""" self.tts_response_finished() def _update_tts_format(self) -> None: """Update the TTS format from the first media player.""" for supported_format in chain(*self.entry_data.media_player_formats.values()): # Find first announcement format if supported_format.purpose == MediaPlayerFormatPurpose.ANNOUNCEMENT: self._attr_tts_options = { tts.ATTR_PREFERRED_FORMAT: supported_format.format, } if supported_format.sample_rate > 0: self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_RATE] = ( supported_format.sample_rate ) if supported_format.sample_rate > 0: self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_CHANNELS] = ( supported_format.num_channels ) if supported_format.sample_rate > 0: self._attr_tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = ( supported_format.sample_bytes ) break async def _stream_tts_audio( self, media_id: str, sample_rate: int = 16000, sample_width: int = 2, sample_channels: int = 1, samples_per_chunk: int = 512, ) -> None: """Stream TTS audio chunks to device via API or UDP.""" self.cli.send_voice_assistant_event( VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {} ) try: if not self._is_running: return extension, data = await tts.async_get_media_source_audio( self.hass, media_id, ) if extension != "wav": _LOGGER.error("Only WAV audio can be streamed, got %s", extension) return with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file: if ( (wav_file.getframerate() != sample_rate) or (wav_file.getsampwidth() != sample_width) or (wav_file.getnchannels() != sample_channels) ): _LOGGER.error("Can only stream 16Khz 16-bit mono WAV") return _LOGGER.debug("Streaming %s audio samples", wav_file.getnframes()) while self._is_running: chunk = wav_file.readframes(samples_per_chunk) if not chunk: break if self._udp_server is not None: self._udp_server.send_audio_bytes(chunk) else: self.cli.send_voice_assistant_audio(chunk) # Wait for 90% of the duration of the audio that was # sent for it to be played. This will overrun the # device's buffer for very long audio, so using a media # player is preferred. samples_in_chunk = len(chunk) // (sample_width * sample_channels) seconds_in_chunk = samples_in_chunk / sample_rate await asyncio.sleep(seconds_in_chunk * 0.9) except asyncio.CancelledError: return # Don't trigger state change finally: self.cli.send_voice_assistant_event( VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END, {} ) # State change self.tts_response_finished() async def _wrap_audio_stream(self) -> AsyncIterable[bytes]: """Yield audio chunks from the queue until None.""" while True: chunk = await self._audio_queue.get() if not chunk: break yield chunk def _stop_pipeline(self) -> None: """Request pipeline to be stopped by ending the audio stream and continue processing.""" self._audio_queue.put_nowait(None) _LOGGER.debug("Requested pipeline stop") def _abort_pipeline(self) -> None: """Request pipeline to be aborted (no further processing).""" _LOGGER.debug("Requested pipeline abort") self._audio_queue.put_nowait(None) if self._pipeline_task is not None: self._pipeline_task.cancel() async def _start_udp_server(self) -> int: """Start a UDP server on a random free port.""" sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.setblocking(False) sock.bind(("", 0)) # random free port ( _transport, protocol, ) = await asyncio.get_running_loop().create_datagram_endpoint( partial(VoiceAssistantUDPServer, self._audio_queue), sock=sock ) assert isinstance(protocol, VoiceAssistantUDPServer) self._udp_server = protocol # Return port return cast(int, sock.getsockname()[1]) def _stop_udp_server(self) -> None: """Stop the UDP server if it's running.""" if self._udp_server is None: return try: self._udp_server.close() finally: self._udp_server = None _LOGGER.debug("Stopped UDP server") class VoiceAssistantUDPServer(asyncio.DatagramProtocol): """Receive UDP packets and forward them to the audio queue.""" transport: asyncio.DatagramTransport | None = None remote_addr: tuple[str, int] | None = None def __init__( self, audio_queue: asyncio.Queue[bytes | None], *args: Any, **kwargs: Any ) -> None: """Initialize protocol.""" super().__init__(*args, **kwargs) self._audio_queue = audio_queue def connection_made(self, transport: asyncio.BaseTransport) -> None: """Store transport for later use.""" self.transport = cast(asyncio.DatagramTransport, transport) def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: """Handle incoming UDP packet.""" if self.remote_addr is None: self.remote_addr = addr self._audio_queue.put_nowait(data) def error_received(self, exc: Exception) -> None: """Handle when a send or receive operation raises an OSError. (Other than BlockingIOError or InterruptedError.) """ _LOGGER.error("ESPHome Voice Assistant UDP server error received: %s", exc) # Stop pipeline self._audio_queue.put_nowait(None) def close(self) -> None: """Close the receiver.""" if self.transport is not None: self.transport.close() self.remote_addr = None def send_audio_bytes(self, data: bytes) -> None: """Send bytes to the device via UDP.""" if self.transport is None: _LOGGER.error("No transport to send audio to") return if self.remote_addr is None: _LOGGER.error("No address to send audio to") return self.transport.sendto(data, self.remote_addr)