From e9f8e7ab5093039c9cf2ab0e9a4f4d1f98418aa4 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 7 Dec 2023 16:02:55 -0600 Subject: [PATCH] Add Wyoming satellite audio settings (#105261) * Add noise suppression level * Add auto gain and volume multiplier * Always use mock TTS dir in Wyoming tests * More tests --- homeassistant/components/wyoming/__init__.py | 7 +- homeassistant/components/wyoming/devices.py | 56 ++++++++++ homeassistant/components/wyoming/number.py | 102 ++++++++++++++++++ homeassistant/components/wyoming/satellite.py | 12 +++ homeassistant/components/wyoming/select.py | 55 +++++++++- homeassistant/components/wyoming/strings.json | 19 +++- tests/components/wyoming/__init__.py | 21 ++++ tests/components/wyoming/conftest.py | 6 ++ .../components/wyoming/test_binary_sensor.py | 5 +- tests/components/wyoming/test_number.py | 102 ++++++++++++++++++ tests/components/wyoming/test_select.py | 60 ++++++++++- tests/components/wyoming/test_switch.py | 9 ++ tests/components/wyoming/test_tts.py | 6 -- 13 files changed, 445 insertions(+), 15 deletions(-) create mode 100644 homeassistant/components/wyoming/number.py create mode 100644 tests/components/wyoming/test_number.py diff --git a/homeassistant/components/wyoming/__init__.py b/homeassistant/components/wyoming/__init__.py index 2cc9b7050a0..88e490d6dc9 100644 --- a/homeassistant/components/wyoming/__init__.py +++ b/homeassistant/components/wyoming/__init__.py @@ -17,7 +17,12 @@ from .satellite import WyomingSatellite _LOGGER = logging.getLogger(__name__) -SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH] +SATELLITE_PLATFORMS = [ + Platform.BINARY_SENSOR, + Platform.SELECT, + Platform.SWITCH, + Platform.NUMBER, +] __all__ = [ "ATTR_SPEAKER", diff --git a/homeassistant/components/wyoming/devices.py b/homeassistant/components/wyoming/devices.py index 90dad889707..bd7252bcf6b 100644 --- a/homeassistant/components/wyoming/devices.py +++ b/homeassistant/components/wyoming/devices.py @@ -19,10 +19,14 @@ class SatelliteDevice: is_active: bool = False is_enabled: bool = True pipeline_name: str | None = None + noise_suppression_level: int = 0 + auto_gain: int = 0 + volume_multiplier: float = 1.0 _is_active_listener: Callable[[], None] | None = None _is_enabled_listener: Callable[[], None] | None = None _pipeline_listener: Callable[[], None] | None = None + _audio_settings_listener: Callable[[], None] | None = None @callback def set_is_active(self, active: bool) -> None: @@ -48,6 +52,30 @@ class SatelliteDevice: if self._pipeline_listener is not None: self._pipeline_listener() + @callback + def set_noise_suppression_level(self, noise_suppression_level: int) -> None: + """Set noise suppression level.""" + if noise_suppression_level != self.noise_suppression_level: + self.noise_suppression_level = noise_suppression_level + if self._audio_settings_listener is not None: + self._audio_settings_listener() + + @callback + def set_auto_gain(self, auto_gain: int) -> None: + """Set auto gain amount.""" + if auto_gain != self.auto_gain: + self.auto_gain = auto_gain + if self._audio_settings_listener is not None: + self._audio_settings_listener() + + @callback + def set_volume_multiplier(self, volume_multiplier: float) -> None: + """Set auto gain amount.""" + if volume_multiplier != self.volume_multiplier: + self.volume_multiplier = volume_multiplier + if self._audio_settings_listener is not None: + self._audio_settings_listener() + @callback def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None: """Listen for updates to is_active.""" @@ -63,6 +91,13 @@ class SatelliteDevice: """Listen for updates to pipeline.""" self._pipeline_listener = pipeline_listener + @callback + def set_audio_settings_listener( + self, audio_settings_listener: Callable[[], None] + ) -> None: + """Listen for updates to audio settings.""" + self._audio_settings_listener = audio_settings_listener + def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None: """Return entity id for assist in progress binary sensor.""" ent_reg = er.async_get(hass) @@ -83,3 +118,24 @@ class SatelliteDevice: return ent_reg.async_get_entity_id( "select", DOMAIN, f"{self.satellite_id}-pipeline" ) + + def get_noise_suppression_level_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for noise suppression select.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "select", DOMAIN, f"{self.satellite_id}-noise_suppression_level" + ) + + def get_auto_gain_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for auto gain amount.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "number", DOMAIN, f"{self.satellite_id}-auto_gain" + ) + + def get_volume_multiplier_entity_id(self, hass: HomeAssistant) -> str | None: + """Return entity id for microphone volume multiplier.""" + ent_reg = er.async_get(hass) + return ent_reg.async_get_entity_id( + "number", DOMAIN, f"{self.satellite_id}-volume_multiplier" + ) diff --git a/homeassistant/components/wyoming/number.py b/homeassistant/components/wyoming/number.py new file mode 100644 index 00000000000..5e769eeb06d --- /dev/null +++ b/homeassistant/components/wyoming/number.py @@ -0,0 +1,102 @@ +"""Number entities for Wyoming integration.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Final + +from homeassistant.components.number import NumberEntityDescription, RestoreNumber +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import EntityCategory +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN +from .entity import WyomingSatelliteEntity + +if TYPE_CHECKING: + from .models import DomainDataItem + +_MAX_AUTO_GAIN: Final = 31 +_MIN_VOLUME_MULTIPLIER: Final = 0.1 +_MAX_VOLUME_MULTIPLIER: Final = 10.0 + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up Wyoming number entities.""" + item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] + + # Setup is only forwarded for satellites + assert item.satellite is not None + + device = item.satellite.device + async_add_entities( + [ + WyomingSatelliteAutoGainNumber(device), + WyomingSatelliteVolumeMultiplierNumber(device), + ] + ) + + +class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber): + """Entity to represent auto gain amount.""" + + entity_description = NumberEntityDescription( + key="auto_gain", + translation_key="auto_gain", + entity_category=EntityCategory.CONFIG, + ) + _attr_should_poll = False + _attr_native_min_value = 0 + _attr_native_max_value = _MAX_AUTO_GAIN + _attr_native_value = 0 + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + + state = await self.async_get_last_state() + if state is not None: + await self.async_set_native_value(float(state.state)) + + async def async_set_native_value(self, value: float) -> None: + """Set new value.""" + auto_gain = int(max(0, min(_MAX_AUTO_GAIN, value))) + self._attr_native_value = auto_gain + self.async_write_ha_state() + self._device.set_auto_gain(auto_gain) + + +class WyomingSatelliteVolumeMultiplierNumber(WyomingSatelliteEntity, RestoreNumber): + """Entity to represent microphone volume multiplier.""" + + entity_description = NumberEntityDescription( + key="volume_multiplier", + translation_key="volume_multiplier", + entity_category=EntityCategory.CONFIG, + ) + _attr_should_poll = False + _attr_native_min_value = _MIN_VOLUME_MULTIPLIER + _attr_native_max_value = _MAX_VOLUME_MULTIPLIER + _attr_native_step = 0.1 + _attr_native_value = 1.0 + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + last_number_data = await self.async_get_last_number_data() + if (last_number_data is not None) and ( + last_number_data.native_value is not None + ): + await self.async_set_native_value(last_number_data.native_value) + + async def async_set_native_value(self, value: float) -> None: + """Set new value.""" + self._attr_native_value = float( + max(_MIN_VOLUME_MULTIPLIER, min(_MAX_VOLUME_MULTIPLIER, value)) + ) + self.async_write_ha_state() + self._device.set_volume_multiplier(self._attr_native_value) diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index caf65db115e..1cc3fde2a9c 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -60,6 +60,7 @@ class WyomingSatellite: self.device.set_is_enabled_listener(self._enabled_changed) self.device.set_pipeline_listener(self._pipeline_changed) + self.device.set_audio_settings_listener(self._audio_settings_changed) async def run(self) -> None: """Run and maintain a connection to satellite.""" @@ -135,6 +136,12 @@ class WyomingSatellite: # Cancel any running pipeline self._audio_queue.put_nowait(None) + def _audio_settings_changed(self) -> None: + """Run when device audio settings.""" + + # Cancel any running pipeline + self._audio_queue.put_nowait(None) + async def _run_once(self) -> None: """Run pipelines until an error occurs.""" self.device.set_is_active(False) @@ -227,6 +234,11 @@ class WyomingSatellite: end_stage=end_stage, tts_audio_output="wav", pipeline_id=pipeline_id, + audio_settings=assist_pipeline.AudioSettings( + noise_suppression_level=self.device.noise_suppression_level, + auto_gain_dbfs=self.device.auto_gain, + volume_multiplier=self.device.volume_multiplier, + ), ) ) diff --git a/homeassistant/components/wyoming/select.py b/homeassistant/components/wyoming/select.py index 2929ae79fa0..c04bad4bef8 100644 --- a/homeassistant/components/wyoming/select.py +++ b/homeassistant/components/wyoming/select.py @@ -1,12 +1,15 @@ -"""Select entities for VoIP integration.""" +"""Select entities for Wyoming integration.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Final from homeassistant.components.assist_pipeline.select import AssistPipelineSelect +from homeassistant.components.select import SelectEntity, SelectEntityDescription from homeassistant.config_entries import ConfigEntry +from homeassistant.const import EntityCategory from homeassistant.core import HomeAssistant +from homeassistant.helpers import restore_state from homeassistant.helpers.entity_platform import AddEntitiesCallback from .const import DOMAIN @@ -16,19 +19,34 @@ from .entity import WyomingSatelliteEntity if TYPE_CHECKING: from .models import DomainDataItem +_NOISE_SUPPRESSION_LEVEL: Final = { + "off": 0, + "low": 1, + "medium": 2, + "high": 3, + "max": 4, +} +_DEFAULT_NOISE_SUPPRESSION_LEVEL: Final = "off" + async def async_setup_entry( hass: HomeAssistant, config_entry: ConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: - """Set up VoIP switch entities.""" + """Set up Wyoming select entities.""" item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id] # Setup is only forwarded for satellites assert item.satellite is not None - async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)]) + device = item.satellite.device + async_add_entities( + [ + WyomingSatellitePipelineSelect(hass, device), + WyomingSatelliteNoiseSuppressionLevelSelect(device), + ] + ) class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect): @@ -45,3 +63,32 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec """Select an option.""" await super().async_select_option(option) self.device.set_pipeline_name(option) + + +class WyomingSatelliteNoiseSuppressionLevelSelect( + WyomingSatelliteEntity, SelectEntity, restore_state.RestoreEntity +): + """Entity to represent noise suppression level setting.""" + + entity_description = SelectEntityDescription( + key="noise_suppression_level", + translation_key="noise_suppression_level", + entity_category=EntityCategory.CONFIG, + ) + _attr_should_poll = False + _attr_current_option = _DEFAULT_NOISE_SUPPRESSION_LEVEL + _attr_options = list(_NOISE_SUPPRESSION_LEVEL.keys()) + + async def async_added_to_hass(self) -> None: + """When entity is added to Home Assistant.""" + await super().async_added_to_hass() + + state = await self.async_get_last_state() + if state is not None and state.state in self.options: + self._attr_current_option = state.state + + async def async_select_option(self, option: str) -> None: + """Select an option.""" + self._attr_current_option = option + self.async_write_ha_state() + self._device.set_noise_suppression_level(_NOISE_SUPPRESSION_LEVEL[option]) diff --git a/homeassistant/components/wyoming/strings.json b/homeassistant/components/wyoming/strings.json index 19b6a513d4b..7b6be68aeb2 100644 --- a/homeassistant/components/wyoming/strings.json +++ b/homeassistant/components/wyoming/strings.json @@ -37,14 +37,29 @@ "preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]" } }, - "noise_suppression": { - "name": "Noise suppression" + "noise_suppression_level": { + "name": "Noise suppression level", + "state": { + "off": "Off", + "low": "Low", + "medium": "Medium", + "high": "High", + "max": "Max" + } } }, "switch": { "satellite_enabled": { "name": "Satellite enabled" } + }, + "number": { + "auto_gain": { + "name": "Auto gain" + }, + "volume_multiplier": { + "name": "Mic volume" + } } } } diff --git a/tests/components/wyoming/__init__.py b/tests/components/wyoming/__init__.py index 899eda7ec1a..268ebef1d06 100644 --- a/tests/components/wyoming/__init__.py +++ b/tests/components/wyoming/__init__.py @@ -1,5 +1,6 @@ """Tests for the Wyoming integration.""" import asyncio +from unittest.mock import patch from wyoming.event import Event from wyoming.info import ( @@ -15,6 +16,10 @@ from wyoming.info import ( WakeProgram, ) +from homeassistant.components.wyoming import DOMAIN +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.core import HomeAssistant + TEST_ATTR = Attribution(name="Test", url="http://www.test.com") STT_INFO = Info( asr=[ @@ -124,3 +129,19 @@ class MockAsyncTcpClient: self.host = host self.port = port return self + + +async def reload_satellite( + hass: HomeAssistant, config_entry_id: str +) -> SatelliteDevice: + """Reload config entry with satellite info and returns new device.""" + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.WyomingSatellite.run" + ) as _run_mock: + # _run_mock: satellite task does not actually run + await hass.config_entries.async_reload(config_entry_id) + + return hass.data[DOMAIN][config_entry_id].satellite.device diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index a30c1048eb6..f22ec7e9e16 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -16,6 +16,12 @@ from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO from tests.common import MockConfigEntry +@pytest.fixture(autouse=True) +def mock_tts_cache_dir_autouse(mock_tts_cache_dir): + """Mock the TTS cache dir with empty dir.""" + return mock_tts_cache_dir + + @pytest.fixture(autouse=True) async def init_components(hass: HomeAssistant): """Set up required components.""" diff --git a/tests/components/wyoming/test_binary_sensor.py b/tests/components/wyoming/test_binary_sensor.py index 27294186a90..fba181a63ca 100644 --- a/tests/components/wyoming/test_binary_sensor.py +++ b/tests/components/wyoming/test_binary_sensor.py @@ -4,6 +4,8 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant +from . import reload_satellite + async def test_assist_in_progress( hass: HomeAssistant, @@ -26,7 +28,8 @@ async def test_assist_in_progress( assert state.state == STATE_ON assert satellite_device.is_active - satellite_device.set_is_active(False) + # test restore does *not* happen + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) state = hass.states.get(assist_in_progress_id) assert state is not None diff --git a/tests/components/wyoming/test_number.py b/tests/components/wyoming/test_number.py new file mode 100644 index 00000000000..084021d61a7 --- /dev/null +++ b/tests/components/wyoming/test_number.py @@ -0,0 +1,102 @@ +"""Test Wyoming number.""" +from unittest.mock import patch + +from homeassistant.components.wyoming.devices import SatelliteDevice +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant + +from . import reload_satellite + + +async def test_auto_gain_number( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test automatic gain control number.""" + agc_entity_id = satellite_device.get_auto_gain_entity_id(hass) + assert agc_entity_id + + state = hass.states.get(agc_entity_id) + assert state is not None + assert int(state.state) == 0 + assert satellite_device.auto_gain == 0 + + # Change setting + with patch.object(satellite_device, "set_auto_gain") as mock_agc_changed: + await hass.services.async_call( + "number", + "set_value", + {"entity_id": agc_entity_id, "value": 31}, + blocking=True, + ) + + state = hass.states.get(agc_entity_id) + assert state is not None + assert int(state.state) == 31 + + # set function should have been called + mock_agc_changed.assert_called_once_with(31) + + # test restore + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) + + state = hass.states.get(agc_entity_id) + assert state is not None + assert int(state.state) == 31 + + await hass.services.async_call( + "number", + "set_value", + {"entity_id": agc_entity_id, "value": 15}, + blocking=True, + ) + + assert satellite_device.auto_gain == 15 + + +async def test_volume_multiplier_number( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test volume multiplier number.""" + vm_entity_id = satellite_device.get_volume_multiplier_entity_id(hass) + assert vm_entity_id + + state = hass.states.get(vm_entity_id) + assert state is not None + assert float(state.state) == 1.0 + assert satellite_device.volume_multiplier == 1.0 + + # Change setting + with patch.object(satellite_device, "set_volume_multiplier") as mock_vm_changed: + await hass.services.async_call( + "number", + "set_value", + {"entity_id": vm_entity_id, "value": 2.0}, + blocking=True, + ) + + state = hass.states.get(vm_entity_id) + assert state is not None + assert float(state.state) == 2.0 + + # set function should have been called + mock_vm_changed.assert_called_once_with(2.0) + + # test restore + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) + + state = hass.states.get(vm_entity_id) + assert state is not None + assert float(state.state) == 2.0 + + await hass.services.async_call( + "number", + "set_value", + {"entity_id": vm_entity_id, "value": 0.5}, + blocking=True, + ) + + assert float(satellite_device.volume_multiplier) == 0.5 diff --git a/tests/components/wyoming/test_select.py b/tests/components/wyoming/test_select.py index cab699336fb..128aab57a1a 100644 --- a/tests/components/wyoming/test_select.py +++ b/tests/components/wyoming/test_select.py @@ -9,6 +9,8 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.setup import async_setup_component +from . import reload_satellite + async def test_pipeline_select( hass: HomeAssistant, @@ -61,9 +63,16 @@ async def test_pipeline_select( assert state is not None assert state.state == "Test 1" - # async_pipeline_changed should have been called + # set function should have been called mock_pipeline_changed.assert_called_once_with("Test 1") + # test restore + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) + + state = hass.states.get(pipeline_entity_id) + assert state is not None + assert state.state == "Test 1" + # Change back and check update listener pipeline_listener = Mock() satellite_device.set_pipeline_listener(pipeline_listener) @@ -81,3 +90,52 @@ async def test_pipeline_select( # listener should have been called pipeline_listener.assert_called_once() + + +async def test_noise_suppression_level_select( + hass: HomeAssistant, + satellite_config_entry: ConfigEntry, + satellite_device: SatelliteDevice, +) -> None: + """Test noise suppression level select.""" + nsl_entity_id = satellite_device.get_noise_suppression_level_entity_id(hass) + assert nsl_entity_id + + state = hass.states.get(nsl_entity_id) + assert state is not None + assert state.state == "off" + assert satellite_device.noise_suppression_level == 0 + + # Change setting + with patch.object( + satellite_device, "set_noise_suppression_level" + ) as mock_nsl_changed: + await hass.services.async_call( + "select", + "select_option", + {"entity_id": nsl_entity_id, "option": "max"}, + blocking=True, + ) + + state = hass.states.get(nsl_entity_id) + assert state is not None + assert state.state == "max" + + # set function should have been called + mock_nsl_changed.assert_called_once_with(4) + + # test restore + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) + + state = hass.states.get(nsl_entity_id) + assert state is not None + assert state.state == "max" + + await hass.services.async_call( + "select", + "select_option", + {"entity_id": nsl_entity_id, "option": "medium"}, + blocking=True, + ) + + assert satellite_device.noise_suppression_level == 2 diff --git a/tests/components/wyoming/test_switch.py b/tests/components/wyoming/test_switch.py index 0b05724d761..a39b7087f6d 100644 --- a/tests/components/wyoming/test_switch.py +++ b/tests/components/wyoming/test_switch.py @@ -4,6 +4,8 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import STATE_OFF, STATE_ON from homeassistant.core import HomeAssistant +from . import reload_satellite + async def test_satellite_enabled( hass: HomeAssistant, @@ -30,3 +32,10 @@ async def test_satellite_enabled( assert state is not None assert state.state == STATE_OFF assert not satellite_device.is_enabled + + # test restore + satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id) + + state = hass.states.get(satellite_enabled_id) + assert state is not None + assert state.state == STATE_OFF diff --git a/tests/components/wyoming/test_tts.py b/tests/components/wyoming/test_tts.py index 2f2a25558e4..301074e8ffb 100644 --- a/tests/components/wyoming/test_tts.py +++ b/tests/components/wyoming/test_tts.py @@ -16,12 +16,6 @@ from homeassistant.helpers.entity_component import DATA_INSTANCES from . import MockAsyncTcpClient -@pytest.fixture(autouse=True) -def mock_tts_cache_dir_autouse(mock_tts_cache_dir): - """Mock the TTS cache dir with empty dir.""" - return mock_tts_cache_dir - - async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None: """Test supported properties.""" state = hass.states.get("tts.test_tts")