Add wake word select for ESPHome Assist satellite (#131309)

* Add wake word select

* Fix linting

* Move to ESPHome

* Clean up and add more tests

* Update homeassistant/components/esphome/select.py

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
pull/131683/head
Michael Hansen 2024-11-26 21:59:49 -06:00 committed by GitHub
parent a97eeaf189
commit 46fe3dcbf1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 352 additions and 10 deletions

View File

@ -95,11 +95,7 @@ async def async_setup_entry(
if entry_data.device_info.voice_assistant_feature_flags_compat(
entry_data.api_version
):
async_add_entities(
[
EsphomeAssistSatellite(entry, entry_data),
]
)
async_add_entities([EsphomeAssistSatellite(entry, entry_data)])
class EsphomeAssistSatellite(
@ -198,6 +194,9 @@ class EsphomeAssistSatellite(
self._satellite_config.max_active_wake_words = config.max_active_wake_words
_LOGGER.debug("Received satellite configuration: %s", self._satellite_config)
# Inform listeners that config has been updated
self.entry_data.async_assist_satellite_config_updated(self._satellite_config)
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
@ -254,6 +253,13 @@ class EsphomeAssistSatellite(
# Will use media player for TTS/announcements
self._update_tts_format()
# Update wake word select when config is updated
self.async_on_remove(
self.entry_data.async_register_assist_satellite_set_wake_word_callback(
self.async_set_wake_word
)
)
async def async_will_remove_from_hass(self) -> None:
"""Run when entity will be removed from hass."""
await super().async_will_remove_from_hass()
@ -478,6 +484,17 @@ class EsphomeAssistSatellite(
"""Handle announcement finished message (also sent for TTS)."""
self.tts_response_finished()
@callback
def async_set_wake_word(self, wake_word_id: str) -> None:
"""Set active wake word and update config on satellite."""
self._satellite_config.active_wake_words = [wake_word_id]
self.config_entry.async_create_background_task(
self.hass,
self.async_set_configuration(self._satellite_config),
"esphome_voice_assistant_set_config",
)
_LOGGER.debug("Setting active wake word: %s", wake_word_id)
def _update_tts_format(self) -> None:
"""Update the TTS format from the first media player."""
for supported_format in chain(*self.entry_data.media_player_formats.values()):

View File

@ -48,6 +48,7 @@ from aioesphomeapi import (
from aioesphomeapi.model import ButtonInfo
from bleak_esphome.backend.device import ESPHomeBluetoothDevice
from homeassistant.components.assist_satellite import AssistSatelliteConfiguration
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
@ -152,6 +153,12 @@ class RuntimeEntryData:
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
default_factory=lambda: defaultdict(list)
)
assist_satellite_config_update_callbacks: list[
Callable[[AssistSatelliteConfiguration], None]
] = field(default_factory=list)
assist_satellite_set_wake_word_callbacks: list[Callable[[str], None]] = field(
default_factory=list
)
@property
def name(self) -> str:
@ -504,3 +511,35 @@ class RuntimeEntryData:
# We use this to determine if a deep sleep device should
# be marked as unavailable or not.
self.expected_disconnect = True
@callback
def async_register_assist_satellite_config_updated_callback(
self,
callback_: Callable[[AssistSatelliteConfiguration], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when the Assist satellite's configuration is updated."""
self.assist_satellite_config_update_callbacks.append(callback_)
return lambda: self.assist_satellite_config_update_callbacks.remove(callback_)
@callback
def async_assist_satellite_config_updated(
self, config: AssistSatelliteConfiguration
) -> None:
"""Notify listeners that the Assist satellite configuration has been updated."""
for callback_ in self.assist_satellite_config_update_callbacks.copy():
callback_(config)
@callback
def async_register_assist_satellite_set_wake_word_callback(
self,
callback_: Callable[[str], None],
) -> CALLBACK_TYPE:
"""Register to receive callbacks when the Assist satellite's wake word is set."""
self.assist_satellite_set_wake_word_callbacks.append(callback_)
return lambda: self.assist_satellite_set_wake_word_callbacks.remove(callback_)
@callback
def async_assist_satellite_set_wake_word(self, wake_word_id: str) -> None:
"""Notify listeners that the Assist satellite wake word has been set."""
for callback_ in self.assist_satellite_set_wake_word_callbacks.copy():
callback_(wake_word_id)

View File

@ -8,8 +8,10 @@ from homeassistant.components.assist_pipeline.select import (
AssistPipelineSelect,
VadSensitivitySelect,
)
from homeassistant.components.select import SelectEntity
from homeassistant.components.assist_satellite import AssistSatelliteConfiguration
from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import restore_state
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
@ -47,6 +49,7 @@ async def async_setup_entry(
[
EsphomeAssistPipelineSelect(hass, entry_data),
EsphomeVadSensitivitySelect(hass, entry_data),
EsphomeAssistSatelliteWakeWordSelect(hass, entry_data),
]
)
@ -89,3 +92,75 @@ class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):
"""Initialize a VAD sensitivity selector."""
EsphomeAssistEntity.__init__(self, entry_data)
VadSensitivitySelect.__init__(self, hass, self._device_info.mac_address)
class EsphomeAssistSatelliteWakeWordSelect(
EsphomeAssistEntity, SelectEntity, restore_state.RestoreEntity
):
"""Wake word selector for esphome devices."""
entity_description = SelectEntityDescription(
key="wake_word", translation_key="wake_word"
)
_attr_should_poll = False
_attr_current_option: str | None = None
_attr_options: list[str] = []
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
"""Initialize a wake word selector."""
EsphomeAssistEntity.__init__(self, entry_data)
unique_id_prefix = self._device_info.mac_address
self._attr_unique_id = f"{unique_id_prefix}-wake_word"
# name -> id
self._wake_words: dict[str, str] = {}
@property
def available(self) -> bool:
"""Return if entity is available."""
return bool(self._attr_options)
async def async_added_to_hass(self) -> None:
"""Run when entity about to be added to hass."""
await super().async_added_to_hass()
# Update options when config is updated
self.async_on_remove(
self._entry_data.async_register_assist_satellite_config_updated_callback(
self.async_satellite_config_updated
)
)
async def async_select_option(self, option: str) -> None:
"""Select an option."""
if wake_word_id := self._wake_words.get(option):
# _attr_current_option will be updated on
# async_satellite_config_updated after the device sets the wake
# word.
self._entry_data.async_assist_satellite_set_wake_word(wake_word_id)
def async_satellite_config_updated(
self, config: AssistSatelliteConfiguration
) -> None:
"""Update options with available wake words."""
if (not config.available_wake_words) or (config.max_active_wake_words < 1):
self._attr_current_option = None
self._wake_words.clear()
self.async_write_ha_state()
return
self._wake_words = {w.wake_word: w.id for w in config.available_wake_words}
self._attr_options = sorted(self._wake_words)
if config.active_wake_words:
# Select first active wake word
wake_word_id = config.active_wake_words[0]
for wake_word in config.available_wake_words:
if wake_word.id == wake_word_id:
self._attr_current_option = wake_word.wake_word
else:
# Select first available wake word
self._attr_current_option = config.available_wake_words[0].wake_word
self.async_write_ha_state()

View File

@ -84,6 +84,12 @@
"aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]",
"relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]"
}
},
"wake_word": {
"name": "Wake word",
"state": {
"okay_nabu": "Okay Nabu"
}
}
},
"climate": {

View File

@ -184,7 +184,7 @@ async def test_select_entity_changing_vad_sensitivity(
hass: HomeAssistant,
init_select: MockConfigEntry,
) -> None:
"""Test entity tracking pipeline changes."""
"""Test entity tracking vad sensitivity changes."""
config_entry = init_select # nicer naming
config_entry.mock_state(hass, ConfigEntryState.LOADED)
@ -192,7 +192,7 @@ async def test_select_entity_changing_vad_sensitivity(
assert state is not None
assert state.state == VadSensitivity.DEFAULT.value
# Change select to new pipeline
# Change select to new sensitivity
await hass.services.async_call(
"select",
"select_option",

View File

@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
from dataclasses import replace
import io
import socket
from unittest.mock import ANY, Mock, patch
from unittest.mock import ANY, AsyncMock, Mock, patch
import wave
from aioesphomeapi import (
@ -42,6 +42,10 @@ from homeassistant.components.esphome.assist_satellite import (
VoiceAssistantUDPServer,
)
from homeassistant.components.media_source import PlayMedia
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import STATE_UNAVAILABLE, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er, intent as intent_helper
@ -1473,3 +1477,194 @@ async def test_get_set_configuration(
# Device should have been updated
assert satellite.async_get_configuration() == updated_config
async def test_wake_word_select(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test wake word select."""
device_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]),
],
active_wake_words=["hey_jarvis"],
max_active_wake_words=1,
)
mock_client.get_voice_assistant_configuration.return_value = device_config
# Wrap mock so we can tell when it's done
configuration_set = asyncio.Event()
async def wrapper(*args, **kwargs):
# Update device config because entity will request it after update
device_config.active_wake_words = kwargs["active_wake_words"]
configuration_set.set()
mock_client.set_voice_assistant_configuration = AsyncMock(side_effect=wrapper)
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert satellite.async_get_configuration().active_wake_words == ["hey_jarvis"]
# Active wake word should be selected
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == "Hey Jarvis"
# Changing the select should set the active wake word
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{"entity_id": "select.test_wake_word", "option": "Okay Nabu"},
blocking=True,
)
await hass.async_block_till_done()
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == "Okay Nabu"
# Wait for device config to be updated
async with asyncio.timeout(1):
await configuration_set.wait()
# Satellite config should have been updated
assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"]
async def test_wake_word_select_no_wake_words(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test wake word select is unavailable when there are no available wake word."""
device_config = AssistSatelliteConfiguration(
available_wake_words=[],
active_wake_words=[],
max_active_wake_words=1,
)
mock_client.get_voice_assistant_configuration.return_value = device_config
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert not satellite.async_get_configuration().available_wake_words
# Select should be unavailable
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == STATE_UNAVAILABLE
async def test_wake_word_select_zero_max_wake_words(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test wake word select is unavailable max wake words is zero."""
device_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
],
active_wake_words=[],
max_active_wake_words=0,
)
mock_client.get_voice_assistant_configuration.return_value = device_config
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert satellite.async_get_configuration().max_active_wake_words == 0
# Select should be unavailable
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == STATE_UNAVAILABLE
async def test_wake_word_select_no_active_wake_words(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test wake word select uses first available wake word if none are active."""
device_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
],
active_wake_words=[],
max_active_wake_words=1,
)
mock_client.get_voice_assistant_configuration.return_value = device_config
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert not satellite.async_get_configuration().active_wake_words
# First available wake word should be selected
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == "Okay Nabu"

View File

@ -9,7 +9,7 @@ from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import ATTR_ENTITY_ID
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant
@ -38,6 +38,16 @@ async def test_vad_sensitivity_select(
assert state.state == "default"
async def test_wake_word_select(
hass: HomeAssistant,
mock_voice_assistant_v1_entry,
) -> None:
"""Test that wake word select is unavailable initially."""
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == STATE_UNAVAILABLE
async def test_select_generic_entity(
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
) -> None: