Add wake word select for ESPHome Assist satellite (#131309)
* Add wake word select * Fix linting * Move to ESPHome * Clean up and add more tests * Update homeassistant/components/esphome/select.py --------- Co-authored-by: Paulus Schoutsen <balloob@gmail.com>pull/131683/head
parent
a97eeaf189
commit
46fe3dcbf1
|
@ -95,11 +95,7 @@ async def async_setup_entry(
|
|||
if entry_data.device_info.voice_assistant_feature_flags_compat(
|
||||
entry_data.api_version
|
||||
):
|
||||
async_add_entities(
|
||||
[
|
||||
EsphomeAssistSatellite(entry, entry_data),
|
||||
]
|
||||
)
|
||||
async_add_entities([EsphomeAssistSatellite(entry, entry_data)])
|
||||
|
||||
|
||||
class EsphomeAssistSatellite(
|
||||
|
@ -198,6 +194,9 @@ class EsphomeAssistSatellite(
|
|||
self._satellite_config.max_active_wake_words = config.max_active_wake_words
|
||||
_LOGGER.debug("Received satellite configuration: %s", self._satellite_config)
|
||||
|
||||
# Inform listeners that config has been updated
|
||||
self.entry_data.async_assist_satellite_config_updated(self._satellite_config)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
@ -254,6 +253,13 @@ class EsphomeAssistSatellite(
|
|||
# Will use media player for TTS/announcements
|
||||
self._update_tts_format()
|
||||
|
||||
# Update wake word select when config is updated
|
||||
self.async_on_remove(
|
||||
self.entry_data.async_register_assist_satellite_set_wake_word_callback(
|
||||
self.async_set_wake_word
|
||||
)
|
||||
)
|
||||
|
||||
async def async_will_remove_from_hass(self) -> None:
|
||||
"""Run when entity will be removed from hass."""
|
||||
await super().async_will_remove_from_hass()
|
||||
|
@ -478,6 +484,17 @@ class EsphomeAssistSatellite(
|
|||
"""Handle announcement finished message (also sent for TTS)."""
|
||||
self.tts_response_finished()
|
||||
|
||||
@callback
|
||||
def async_set_wake_word(self, wake_word_id: str) -> None:
|
||||
"""Set active wake word and update config on satellite."""
|
||||
self._satellite_config.active_wake_words = [wake_word_id]
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self.async_set_configuration(self._satellite_config),
|
||||
"esphome_voice_assistant_set_config",
|
||||
)
|
||||
_LOGGER.debug("Setting active wake word: %s", wake_word_id)
|
||||
|
||||
def _update_tts_format(self) -> None:
|
||||
"""Update the TTS format from the first media player."""
|
||||
for supported_format in chain(*self.entry_data.media_player_formats.values()):
|
||||
|
|
|
@ -48,6 +48,7 @@ from aioesphomeapi import (
|
|||
from aioesphomeapi.model import ButtonInfo
|
||||
from bleak_esphome.backend.device import ESPHomeBluetoothDevice
|
||||
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteConfiguration
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
|
||||
|
@ -152,6 +153,12 @@ class RuntimeEntryData:
|
|||
media_player_formats: dict[str, list[MediaPlayerSupportedFormat]] = field(
|
||||
default_factory=lambda: defaultdict(list)
|
||||
)
|
||||
assist_satellite_config_update_callbacks: list[
|
||||
Callable[[AssistSatelliteConfiguration], None]
|
||||
] = field(default_factory=list)
|
||||
assist_satellite_set_wake_word_callbacks: list[Callable[[str], None]] = field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
@ -504,3 +511,35 @@ class RuntimeEntryData:
|
|||
# We use this to determine if a deep sleep device should
|
||||
# be marked as unavailable or not.
|
||||
self.expected_disconnect = True
|
||||
|
||||
@callback
|
||||
def async_register_assist_satellite_config_updated_callback(
|
||||
self,
|
||||
callback_: Callable[[AssistSatelliteConfiguration], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when the Assist satellite's configuration is updated."""
|
||||
self.assist_satellite_config_update_callbacks.append(callback_)
|
||||
return lambda: self.assist_satellite_config_update_callbacks.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_assist_satellite_config_updated(
|
||||
self, config: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Notify listeners that the Assist satellite configuration has been updated."""
|
||||
for callback_ in self.assist_satellite_config_update_callbacks.copy():
|
||||
callback_(config)
|
||||
|
||||
@callback
|
||||
def async_register_assist_satellite_set_wake_word_callback(
|
||||
self,
|
||||
callback_: Callable[[str], None],
|
||||
) -> CALLBACK_TYPE:
|
||||
"""Register to receive callbacks when the Assist satellite's wake word is set."""
|
||||
self.assist_satellite_set_wake_word_callbacks.append(callback_)
|
||||
return lambda: self.assist_satellite_set_wake_word_callbacks.remove(callback_)
|
||||
|
||||
@callback
|
||||
def async_assist_satellite_set_wake_word(self, wake_word_id: str) -> None:
|
||||
"""Notify listeners that the Assist satellite wake word has been set."""
|
||||
for callback_ in self.assist_satellite_set_wake_word_callbacks.copy():
|
||||
callback_(wake_word_id)
|
||||
|
|
|
@ -8,8 +8,10 @@ from homeassistant.components.assist_pipeline.select import (
|
|||
AssistPipelineSelect,
|
||||
VadSensitivitySelect,
|
||||
)
|
||||
from homeassistant.components.select import SelectEntity
|
||||
from homeassistant.components.assist_satellite import AssistSatelliteConfiguration
|
||||
from homeassistant.components.select import SelectEntity, SelectEntityDescription
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import restore_state
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
|
@ -47,6 +49,7 @@ async def async_setup_entry(
|
|||
[
|
||||
EsphomeAssistPipelineSelect(hass, entry_data),
|
||||
EsphomeVadSensitivitySelect(hass, entry_data),
|
||||
EsphomeAssistSatelliteWakeWordSelect(hass, entry_data),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -89,3 +92,75 @@ class EsphomeVadSensitivitySelect(EsphomeAssistEntity, VadSensitivitySelect):
|
|||
"""Initialize a VAD sensitivity selector."""
|
||||
EsphomeAssistEntity.__init__(self, entry_data)
|
||||
VadSensitivitySelect.__init__(self, hass, self._device_info.mac_address)
|
||||
|
||||
|
||||
class EsphomeAssistSatelliteWakeWordSelect(
|
||||
EsphomeAssistEntity, SelectEntity, restore_state.RestoreEntity
|
||||
):
|
||||
"""Wake word selector for esphome devices."""
|
||||
|
||||
entity_description = SelectEntityDescription(
|
||||
key="wake_word", translation_key="wake_word"
|
||||
)
|
||||
_attr_should_poll = False
|
||||
_attr_current_option: str | None = None
|
||||
_attr_options: list[str] = []
|
||||
|
||||
def __init__(self, hass: HomeAssistant, entry_data: RuntimeEntryData) -> None:
|
||||
"""Initialize a wake word selector."""
|
||||
EsphomeAssistEntity.__init__(self, entry_data)
|
||||
|
||||
unique_id_prefix = self._device_info.mac_address
|
||||
self._attr_unique_id = f"{unique_id_prefix}-wake_word"
|
||||
|
||||
# name -> id
|
||||
self._wake_words: dict[str, str] = {}
|
||||
|
||||
@property
|
||||
def available(self) -> bool:
|
||||
"""Return if entity is available."""
|
||||
return bool(self._attr_options)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Run when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
# Update options when config is updated
|
||||
self.async_on_remove(
|
||||
self._entry_data.async_register_assist_satellite_config_updated_callback(
|
||||
self.async_satellite_config_updated
|
||||
)
|
||||
)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
if wake_word_id := self._wake_words.get(option):
|
||||
# _attr_current_option will be updated on
|
||||
# async_satellite_config_updated after the device sets the wake
|
||||
# word.
|
||||
self._entry_data.async_assist_satellite_set_wake_word(wake_word_id)
|
||||
|
||||
def async_satellite_config_updated(
|
||||
self, config: AssistSatelliteConfiguration
|
||||
) -> None:
|
||||
"""Update options with available wake words."""
|
||||
if (not config.available_wake_words) or (config.max_active_wake_words < 1):
|
||||
self._attr_current_option = None
|
||||
self._wake_words.clear()
|
||||
self.async_write_ha_state()
|
||||
return
|
||||
|
||||
self._wake_words = {w.wake_word: w.id for w in config.available_wake_words}
|
||||
self._attr_options = sorted(self._wake_words)
|
||||
|
||||
if config.active_wake_words:
|
||||
# Select first active wake word
|
||||
wake_word_id = config.active_wake_words[0]
|
||||
for wake_word in config.available_wake_words:
|
||||
if wake_word.id == wake_word_id:
|
||||
self._attr_current_option = wake_word.wake_word
|
||||
else:
|
||||
# Select first available wake word
|
||||
self._attr_current_option = config.available_wake_words[0].wake_word
|
||||
|
||||
self.async_write_ha_state()
|
||||
|
|
|
@ -84,6 +84,12 @@
|
|||
"aggressive": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::aggressive%]",
|
||||
"relaxed": "[%key:component::assist_pipeline::entity::select::vad_sensitivity::state::relaxed%]"
|
||||
}
|
||||
},
|
||||
"wake_word": {
|
||||
"name": "Wake word",
|
||||
"state": {
|
||||
"okay_nabu": "Okay Nabu"
|
||||
}
|
||||
}
|
||||
},
|
||||
"climate": {
|
||||
|
|
|
@ -184,7 +184,7 @@ async def test_select_entity_changing_vad_sensitivity(
|
|||
hass: HomeAssistant,
|
||||
init_select: MockConfigEntry,
|
||||
) -> None:
|
||||
"""Test entity tracking pipeline changes."""
|
||||
"""Test entity tracking vad sensitivity changes."""
|
||||
config_entry = init_select # nicer naming
|
||||
config_entry.mock_state(hass, ConfigEntryState.LOADED)
|
||||
|
||||
|
@ -192,7 +192,7 @@ async def test_select_entity_changing_vad_sensitivity(
|
|||
assert state is not None
|
||||
assert state.state == VadSensitivity.DEFAULT.value
|
||||
|
||||
# Change select to new pipeline
|
||||
# Change select to new sensitivity
|
||||
await hass.services.async_call(
|
||||
"select",
|
||||
"select_option",
|
||||
|
|
|
@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable
|
|||
from dataclasses import replace
|
||||
import io
|
||||
import socket
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||
import wave
|
||||
|
||||
from aioesphomeapi import (
|
||||
|
@ -42,6 +42,10 @@ from homeassistant.components.esphome.assist_satellite import (
|
|||
VoiceAssistantUDPServer,
|
||||
)
|
||||
from homeassistant.components.media_source import PlayMedia
|
||||
from homeassistant.components.select import (
|
||||
DOMAIN as SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
)
|
||||
from homeassistant.const import STATE_UNAVAILABLE, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import entity_registry as er, intent as intent_helper
|
||||
|
@ -1473,3 +1477,194 @@ async def test_get_set_configuration(
|
|||
|
||||
# Device should have been updated
|
||||
assert satellite.async_get_configuration() == updated_config
|
||||
|
||||
|
||||
async def test_wake_word_select(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test wake word select."""
|
||||
device_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[
|
||||
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
|
||||
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
|
||||
AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]),
|
||||
],
|
||||
active_wake_words=["hey_jarvis"],
|
||||
max_active_wake_words=1,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
# Wrap mock so we can tell when it's done
|
||||
configuration_set = asyncio.Event()
|
||||
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Update device config because entity will request it after update
|
||||
device_config.active_wake_words = kwargs["active_wake_words"]
|
||||
configuration_set.set()
|
||||
|
||||
mock_client.set_voice_assistant_configuration = AsyncMock(side_effect=wrapper)
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
assert satellite.async_get_configuration().active_wake_words == ["hey_jarvis"]
|
||||
|
||||
# Active wake word should be selected
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == "Hey Jarvis"
|
||||
|
||||
# Changing the select should set the active wake word
|
||||
await hass.services.async_call(
|
||||
SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
{"entity_id": "select.test_wake_word", "option": "Okay Nabu"},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == "Okay Nabu"
|
||||
|
||||
# Wait for device config to be updated
|
||||
async with asyncio.timeout(1):
|
||||
await configuration_set.wait()
|
||||
|
||||
# Satellite config should have been updated
|
||||
assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"]
|
||||
|
||||
|
||||
async def test_wake_word_select_no_wake_words(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test wake word select is unavailable when there are no available wake word."""
|
||||
device_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[],
|
||||
active_wake_words=[],
|
||||
max_active_wake_words=1,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
assert not satellite.async_get_configuration().available_wake_words
|
||||
|
||||
# Select should be unavailable
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_wake_word_select_zero_max_wake_words(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test wake word select is unavailable max wake words is zero."""
|
||||
device_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[
|
||||
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
|
||||
],
|
||||
active_wake_words=[],
|
||||
max_active_wake_words=0,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
assert satellite.async_get_configuration().max_active_wake_words == 0
|
||||
|
||||
# Select should be unavailable
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_wake_word_select_no_active_wake_words(
|
||||
hass: HomeAssistant,
|
||||
mock_client: APIClient,
|
||||
mock_esphome_device: Callable[
|
||||
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
|
||||
Awaitable[MockESPHomeDevice],
|
||||
],
|
||||
) -> None:
|
||||
"""Test wake word select uses first available wake word if none are active."""
|
||||
device_config = AssistSatelliteConfiguration(
|
||||
available_wake_words=[
|
||||
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
|
||||
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
|
||||
],
|
||||
active_wake_words=[],
|
||||
max_active_wake_words=1,
|
||||
)
|
||||
mock_client.get_voice_assistant_configuration.return_value = device_config
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
entity_info=[],
|
||||
user_service=[],
|
||||
states=[],
|
||||
device_info={
|
||||
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
|
||||
| VoiceAssistantFeature.ANNOUNCE
|
||||
},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
|
||||
assert satellite is not None
|
||||
assert not satellite.async_get_configuration().active_wake_words
|
||||
|
||||
# First available wake word should be selected
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == "Okay Nabu"
|
||||
|
|
|
@ -9,7 +9,7 @@ from homeassistant.components.select import (
|
|||
DOMAIN as SELECT_DOMAIN,
|
||||
SERVICE_SELECT_OPTION,
|
||||
)
|
||||
from homeassistant.const import ATTR_ENTITY_ID
|
||||
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
|
@ -38,6 +38,16 @@ async def test_vad_sensitivity_select(
|
|||
assert state.state == "default"
|
||||
|
||||
|
||||
async def test_wake_word_select(
|
||||
hass: HomeAssistant,
|
||||
mock_voice_assistant_v1_entry,
|
||||
) -> None:
|
||||
"""Test that wake word select is unavailable initially."""
|
||||
state = hass.states.get("select.test_wake_word")
|
||||
assert state is not None
|
||||
assert state.state == STATE_UNAVAILABLE
|
||||
|
||||
|
||||
async def test_select_generic_entity(
|
||||
hass: HomeAssistant, mock_client: APIClient, mock_generic_device_entry
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue