From 46fe3dcbf1c3e7c7431e170afbcde8eff6eb1011 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 26 Nov 2024 21:59:49 -0600 Subject: [PATCH] Add wake word select for ESPHome Assist satellite (#131309) * Add wake word select * Fix linting * Move to ESPHome * Clean up and add more tests * Update homeassistant/components/esphome/select.py --------- Co-authored-by: Paulus Schoutsen --- .../components/esphome/assist_satellite.py | 27 ++- .../components/esphome/entry_data.py | 39 ++++ homeassistant/components/esphome/select.py | 77 ++++++- homeassistant/components/esphome/strings.json | 6 + .../components/assist_pipeline/test_select.py | 4 +- .../esphome/test_assist_satellite.py | 197 +++++++++++++++++- tests/components/esphome/test_select.py | 12 +- 7 files changed, 352 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index dc513a03e02..f60668b0a06 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -95,11 +95,7 @@ async def async_setup_entry( if entry_data.device_info.voice_assistant_feature_flags_compat( entry_data.api_version ): - async_add_entities( - [ - EsphomeAssistSatellite(entry, entry_data), - ] - ) + async_add_entities([EsphomeAssistSatellite(entry, entry_data)]) class EsphomeAssistSatellite( @@ -198,6 +194,9 @@ class EsphomeAssistSatellite( self._satellite_config.max_active_wake_words = config.max_active_wake_words _LOGGER.debug("Received satellite configuration: %s", self._satellite_config) + # Inform listeners that config has been updated + self.entry_data.async_assist_satellite_config_updated(self._satellite_config) + async def async_added_to_hass(self) -> None: """Run when entity about to be added to hass.""" await super().async_added_to_hass() @@ -254,6 +253,13 @@ class EsphomeAssistSatellite( # Will use media player for TTS/announcements self._update_tts_format() + # Update wake word select when config is updated + self.async_on_remove( + self.entry_data.async_register_assist_satellite_set_wake_word_callback( + self.async_set_wake_word + ) + ) + async def async_will_remove_from_hass(self) -> None: """Run when entity will be removed from hass.""" await super().async_will_remove_from_hass() @@ -478,6 +484,17 @@ class EsphomeAssistSatellite( """Handle announcement finished message (also sent for TTS).""" self.tts_response_finished() + @callback + def async_set_wake_word(self, wake_word_id: str) -> None: + """Set active wake word and update config on satellite.""" + self._satellite_config.active_wake_words = [wake_word_id] + self.config_entry.async_create_background_task( + self.hass, + self.async_set_configuration(self._satellite_config), + "esphome_voice_assistant_set_config", + ) + _LOGGER.debug("Setting active wake word: %s", wake_word_id) + def _update_tts_format(self) -> None: """Update the TTS format from the first media player.""" for supported_format in chain(*self.entry_data.media_player_formats.values()): diff --git a/homeassistant/components/esphome/entry_data.py b/homeassistant/components/esphome/entry_data.py index f1b5218eec7..fc41ee99a00 100644 --- a/homeassistant/components/esphome/entry_data.py +++ b/homeassistant/components/esphome/entry_data.py @@ -48,6 +48,7 @@ from aioesphomeapi import ( from aioesphomeapi.model import ButtonInfo from bleak_esphome.backend.device import ESPHomeBluetoothDevice +from homeassistant.components.assist_satellite import AssistSatelliteConfiguration from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback @@ -152,6 +153,12 @@ class RuntimeEntryData: media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field( default_factory=lambda: defaultdict(list) ) + assist_satellite_config_update_callbacks: list[ + Callable[[AssistSatelliteConfiguration], None] + ] = field(default_factory=list) + assist_satellite_set_wake_word_callbacks: list[Callable[[str], None]] = field( + default_factory=list + ) @property def name(self) -> str: @@ -504,3 +511,35 @@ class RuntimeEntryData: # We use this to determine if a deep sleep device should # be marked as unavailable or not. self.expected_disconnect = True + + @callback + def async_register_assist_satellite_config_updated_callback( + self, + callback_: Callable[[AssistSatelliteConfiguration], None], + ) -> CALLBACK_TYPE: + """Register to receive callbacks when the Assist satellite's configuration is updated.""" + self.assist_satellite_config_update_callbacks.append(callback_) + return lambda: self.assist_satellite_config_update_callbacks.remove(callback_) + + @callback + def async_assist_satellite_config_updated( + self, config: AssistSatelliteConfiguration + ) -> None: + """Notify listeners that the Assist satellite configuration has been updated.""" + for callback_ in self.assist_satellite_config_update_callbacks.copy(): + callback_(config) + + @callback + def async_register_assist_satellite_set_wake_word_callback( + self, + callback_: Callable[[str], None], + ) -> CALLBACK_TYPE: + """Register to receive callbacks when the Assist satellite's wake word is set.""" + self.assist_satellite_set_wake_word_callbacks.append(callback_) + return lambda: self.assist_satellite_set_wake_word_callbacks.remove(callback_) + + @callback + def async_assist_satellite_set_wake_word(self, wake_word_id: str) -> None: + """Notify listeners that the Assist satellite wake word has been set.""" + for callback_ in self.assist_satellite_set_wake_word_callbacks.copy(): + callback_(wake_word_id) diff --git a/homeassistant/components/esphome/select.py b/homeassistant/components/esphome/select.py index 623946503eb..ab7654478a7 100644 --- a/homeassistant/components/esphome/select.py +++ b/homeassistant/components/esphome/select.py @@ -8,8 +8,10 @@ from homeassistant.components.assist_pipeline.select import ( AssistPipelineSelect, VadSensitivitySelect, ) -from homeassistant.components.select import SelectEntity +from homeassistant.components.assist_satellite import AssistSatelliteConfiguration +from homeassistant.components.select import SelectEntity, SelectEntityDescription from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import restore_state from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN @@ -47,6 +49,7 @@ async def async_setup_entry( [ EsphomeAssistPipelineSelect(hass, entry_data), EsphomeVadSensitivitySelect(hass, entry_data), + EsphomeAssistSatelliteWakeWordSelect(hass, entry_data), ] ) @@ -89,3 +92,75 @@ class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect): """Initialize a VAD sensitivity selector.""" EsphomeAssistEntity.__init__(self, entry_data) VadSensitivitySelect.__init__(self, hass, self._device_info.mac_address) + + +class EsphomeAssistSatelliteWakeWordSelect( + EsphomeAssistEntity, SelectEntity, restore_state.RestoreEntity +): + """Wake word selector for esphome devices.""" + + entity_description = SelectEntityDescription( + key="wake_word", translation_key="wake_word" + ) + _attr_should_poll = False + _attr_current_option: str | None = None + _attr_options: list[str] = [] + + def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None: + """Initialize a wake word selector.""" + EsphomeAssistEntity.__init__(self, entry_data) + + unique_id_prefix = self._device_info.mac_address + self._attr_unique_id = f"{unique_id_prefix}-wake_word" + + # name -> id + self._wake_words: dict[str, str] = {} + + @property + def available(self) -> bool: + """Return if entity is available.""" + return bool(self._attr_options) + + async def async_added_to_hass(self) -> None: + """Run when entity about to be added to hass.""" + await super().async_added_to_hass() + + # Update options when config is updated + self.async_on_remove( + self._entry_data.async_register_assist_satellite_config_updated_callback( + self.async_satellite_config_updated + ) + ) + + async def async_select_option(self, option: str) -> None: + """Select an option.""" + if wake_word_id := self._wake_words.get(option): + # _attr_current_option will be updated on + # async_satellite_config_updated after the device sets the wake + # word. + self._entry_data.async_assist_satellite_set_wake_word(wake_word_id) + + def async_satellite_config_updated( + self, config: AssistSatelliteConfiguration + ) -> None: + """Update options with available wake words.""" + if (not config.available_wake_words) or (config.max_active_wake_words < 1): + self._attr_current_option = None + self._wake_words.clear() + self.async_write_ha_state() + return + + self._wake_words = {w.wake_word: w.id for w in config.available_wake_words} + self._attr_options = sorted(self._wake_words) + + if config.active_wake_words: + # Select first active wake word + wake_word_id = config.active_wake_words[0] + for wake_word in config.available_wake_words: + if wake_word.id == wake_word_id: + self._attr_current_option = wake_word.wake_word + else: + # Select first available wake word + self._attr_current_option = config.available_wake_words[0].wake_word + + self.async_write_ha_state() diff --git a/homeassistant/components/esphome/strings.json b/homeassistant/components/esphome/strings.json index 971a489a9e2..81b58de8df2 100644 --- a/homeassistant/components/esphome/strings.json +++ b/homeassistant/components/esphome/strings.json @@ -84,6 +84,12 @@ "aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]", "relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]" } + }, + "wake_word": { + "name": "Wake word", + "state": { + "okay_nabu": "Okay Nabu" + } } }, "climate": { diff --git a/tests/components/assist_pipeline/test_select.py b/tests/components/assist_pipeline/test_select.py index 9fb02e228d8..5ce3b1020d0 100644 --- a/tests/components/assist_pipeline/test_select.py +++ b/tests/components/assist_pipeline/test_select.py @@ -184,7 +184,7 @@ async def test_select_entity_changing_vad_sensitivity( hass: HomeAssistant, init_select: MockConfigEntry, ) -> None: - """Test entity tracking pipeline changes.""" + """Test entity tracking vad sensitivity changes.""" config_entry = init_select # nicer naming config_entry.mock_state(hass, ConfigEntryState.LOADED) @@ -192,7 +192,7 @@ async def test_select_entity_changing_vad_sensitivity( assert state is not None assert state.state == VadSensitivity.DEFAULT.value - # Change select to new pipeline + # Change select to new sensitivity await hass.services.async_call( "select", "select_option", diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py index e8344e50161..5ca333df1e2 100644 --- a/tests/components/esphome/test_assist_satellite.py +++ b/tests/components/esphome/test_assist_satellite.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable from dataclasses import replace import io import socket -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch import wave from aioesphomeapi import ( @@ -42,6 +42,10 @@ from homeassistant.components.esphome.assist_satellite import ( VoiceAssistantUDPServer, ) from homeassistant.components.media_source import PlayMedia +from homeassistant.components.select import ( + DOMAIN as SELECT_DOMAIN, + SERVICE_SELECT_OPTION, +) from homeassistant.const import STATE_UNAVAILABLE, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er, intent as intent_helper @@ -1473,3 +1477,194 @@ async def test_get_set_configuration( # Device should have been updated assert satellite.async_get_configuration() == updated_config + + +async def test_wake_word_select( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test wake word select.""" + device_config = AssistSatelliteConfiguration( + available_wake_words=[ + AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]), + AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]), + AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]), + ], + active_wake_words=["hey_jarvis"], + max_active_wake_words=1, + ) + mock_client.get_voice_assistant_configuration.return_value = device_config + + # Wrap mock so we can tell when it's done + configuration_set = asyncio.Event() + + async def wrapper(*args, **kwargs): + # Update device config because entity will request it after update + device_config.active_wake_words = kwargs["active_wake_words"] + configuration_set.set() + + mock_client.set_voice_assistant_configuration = AsyncMock(side_effect=wrapper) + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.ANNOUNCE + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + assert satellite.async_get_configuration().active_wake_words == ["hey_jarvis"] + + # Active wake word should be selected + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == "Hey Jarvis" + + # Changing the select should set the active wake word + await hass.services.async_call( + SELECT_DOMAIN, + SERVICE_SELECT_OPTION, + {"entity_id": "select.test_wake_word", "option": "Okay Nabu"}, + blocking=True, + ) + await hass.async_block_till_done() + + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == "Okay Nabu" + + # Wait for device config to be updated + async with asyncio.timeout(1): + await configuration_set.wait() + + # Satellite config should have been updated + assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"] + + +async def test_wake_word_select_no_wake_words( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test wake word select is unavailable when there are no available wake word.""" + device_config = AssistSatelliteConfiguration( + available_wake_words=[], + active_wake_words=[], + max_active_wake_words=1, + ) + mock_client.get_voice_assistant_configuration.return_value = device_config + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.ANNOUNCE + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + assert not satellite.async_get_configuration().available_wake_words + + # Select should be unavailable + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == STATE_UNAVAILABLE + + +async def test_wake_word_select_zero_max_wake_words( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test wake word select is unavailable max wake words is zero.""" + device_config = AssistSatelliteConfiguration( + available_wake_words=[ + AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]), + ], + active_wake_words=[], + max_active_wake_words=0, + ) + mock_client.get_voice_assistant_configuration.return_value = device_config + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.ANNOUNCE + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + assert satellite.async_get_configuration().max_active_wake_words == 0 + + # Select should be unavailable + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == STATE_UNAVAILABLE + + +async def test_wake_word_select_no_active_wake_words( + hass: HomeAssistant, + mock_client: APIClient, + mock_esphome_device: Callable[ + [APIClient, list[EntityInfo], list[UserService], list[EntityState]], + Awaitable[MockESPHomeDevice], + ], +) -> None: + """Test wake word select uses first available wake word if none are active.""" + device_config = AssistSatelliteConfiguration( + available_wake_words=[ + AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]), + AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]), + ], + active_wake_words=[], + max_active_wake_words=1, + ) + mock_client.get_voice_assistant_configuration.return_value = device_config + + mock_device: MockESPHomeDevice = await mock_esphome_device( + mock_client=mock_client, + entity_info=[], + user_service=[], + states=[], + device_info={ + "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT + | VoiceAssistantFeature.ANNOUNCE + }, + ) + await hass.async_block_till_done() + + satellite = get_satellite_entity(hass, mock_device.device_info.mac_address) + assert satellite is not None + assert not satellite.async_get_configuration().active_wake_words + + # First available wake word should be selected + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == "Okay Nabu" diff --git a/tests/components/esphome/test_select.py b/tests/components/esphome/test_select.py index fbe30afd042..6ae1260a89d 100644 --- a/tests/components/esphome/test_select.py +++ b/tests/components/esphome/test_select.py @@ -9,7 +9,7 @@ from homeassistant.components.select import ( DOMAIN as SELECT_DOMAIN, SERVICE_SELECT_OPTION, ) -from homeassistant.const import ATTR_ENTITY_ID +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE from homeassistant.core import HomeAssistant @@ -38,6 +38,16 @@ async def test_vad_sensitivity_select( assert state.state == "default" +async def test_wake_word_select( + hass: HomeAssistant, + mock_voice_assistant_v1_entry, +) -> None: + """Test that wake word select is unavailable initially.""" + state = hass.states.get("select.test_wake_word") + assert state is not None + assert state.state == STATE_UNAVAILABLE + + async def test_select_generic_entity( hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry ) -> None: