Add Wyoming satellite audio settings (#105261)

* Add noise suppression level

* Add auto gain and volume multiplier

* Always use mock TTS dir in Wyoming tests

* More tests
pull/105265/head
Michael Hansen 2023-12-07 16:02:55 -06:00 committed by GitHub
parent d1aa690c24
commit e9f8e7ab50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 445 additions and 15 deletions

View File

@ -17,7 +17,12 @@ from .satellite import WyomingSatellite
_LOGGER = logging.getLogger(__name__)
SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH]
SATELLITE_PLATFORMS = [
Platform.BINARY_SENSOR,
Platform.SELECT,
Platform.SWITCH,
Platform.NUMBER,
]
__all__ = [
"ATTR_SPEAKER",

View File

@ -19,10 +19,14 @@ class SatelliteDevice:
is_active: bool = False
is_enabled: bool = True
pipeline_name: str | None = None
noise_suppression_level: int = 0
auto_gain: int = 0
volume_multiplier: float = 1.0
_is_active_listener: Callable[[], None] | None = None
_is_enabled_listener: Callable[[], None] | None = None
_pipeline_listener: Callable[[], None] | None = None
_audio_settings_listener: Callable[[], None] | None = None
@callback
def set_is_active(self, active: bool) -> None:
@ -48,6 +52,30 @@ class SatelliteDevice:
if self._pipeline_listener is not None:
self._pipeline_listener()
@callback
def set_noise_suppression_level(self, noise_suppression_level: int) -> None:
"""Set noise suppression level."""
if noise_suppression_level != self.noise_suppression_level:
self.noise_suppression_level = noise_suppression_level
if self._audio_settings_listener is not None:
self._audio_settings_listener()
@callback
def set_auto_gain(self, auto_gain: int) -> None:
"""Set auto gain amount."""
if auto_gain != self.auto_gain:
self.auto_gain = auto_gain
if self._audio_settings_listener is not None:
self._audio_settings_listener()
@callback
def set_volume_multiplier(self, volume_multiplier: float) -> None:
"""Set auto gain amount."""
if volume_multiplier != self.volume_multiplier:
self.volume_multiplier = volume_multiplier
if self._audio_settings_listener is not None:
self._audio_settings_listener()
@callback
def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None:
"""Listen for updates to is_active."""
@ -63,6 +91,13 @@ class SatelliteDevice:
"""Listen for updates to pipeline."""
self._pipeline_listener = pipeline_listener
@callback
def set_audio_settings_listener(
self, audio_settings_listener: Callable[[], None]
) -> None:
"""Listen for updates to audio settings."""
self._audio_settings_listener = audio_settings_listener
def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for assist in progress binary sensor."""
ent_reg = er.async_get(hass)
@ -83,3 +118,24 @@ class SatelliteDevice:
return ent_reg.async_get_entity_id(
"select", DOMAIN, f"{self.satellite_id}-pipeline"
)
def get_noise_suppression_level_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for noise suppression select."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"select", DOMAIN, f"{self.satellite_id}-noise_suppression_level"
)
def get_auto_gain_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for auto gain amount."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"number", DOMAIN, f"{self.satellite_id}-auto_gain"
)
def get_volume_multiplier_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for microphone volume multiplier."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"number", DOMAIN, f"{self.satellite_id}-volume_multiplier"
)

View File

@ -0,0 +1,102 @@
"""Number entities for Wyoming integration."""
from __future__ import annotations
from typing import TYPE_CHECKING, Final
from homeassistant.components.number import NumberEntityDescription, RestoreNumber
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
if TYPE_CHECKING:
from .models import DomainDataItem
_MAX_AUTO_GAIN: Final = 31
_MIN_VOLUME_MULTIPLIER: Final = 0.1
_MAX_VOLUME_MULTIPLIER: Final = 10.0
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up Wyoming number entities."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
device = item.satellite.device
async_add_entities(
[
WyomingSatelliteAutoGainNumber(device),
WyomingSatelliteVolumeMultiplierNumber(device),
]
)
class WyomingSatelliteAutoGainNumber(WyomingSatelliteEntity, RestoreNumber):
"""Entity to represent auto gain amount."""
entity_description = NumberEntityDescription(
key="auto_gain",
translation_key="auto_gain",
entity_category=EntityCategory.CONFIG,
)
_attr_should_poll = False
_attr_native_min_value = 0
_attr_native_max_value = _MAX_AUTO_GAIN
_attr_native_value = 0
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
state = await self.async_get_last_state()
if state is not None:
await self.async_set_native_value(float(state.state))
async def async_set_native_value(self, value: float) -> None:
"""Set new value."""
auto_gain = int(max(0, min(_MAX_AUTO_GAIN, value)))
self._attr_native_value = auto_gain
self.async_write_ha_state()
self._device.set_auto_gain(auto_gain)
class WyomingSatelliteVolumeMultiplierNumber(WyomingSatelliteEntity, RestoreNumber):
"""Entity to represent microphone volume multiplier."""
entity_description = NumberEntityDescription(
key="volume_multiplier",
translation_key="volume_multiplier",
entity_category=EntityCategory.CONFIG,
)
_attr_should_poll = False
_attr_native_min_value = _MIN_VOLUME_MULTIPLIER
_attr_native_max_value = _MAX_VOLUME_MULTIPLIER
_attr_native_step = 0.1
_attr_native_value = 1.0
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
last_number_data = await self.async_get_last_number_data()
if (last_number_data is not None) and (
last_number_data.native_value is not None
):
await self.async_set_native_value(last_number_data.native_value)
async def async_set_native_value(self, value: float) -> None:
"""Set new value."""
self._attr_native_value = float(
max(_MIN_VOLUME_MULTIPLIER, min(_MAX_VOLUME_MULTIPLIER, value))
)
self.async_write_ha_state()
self._device.set_volume_multiplier(self._attr_native_value)

View File

@ -60,6 +60,7 @@ class WyomingSatellite:
self.device.set_is_enabled_listener(self._enabled_changed)
self.device.set_pipeline_listener(self._pipeline_changed)
self.device.set_audio_settings_listener(self._audio_settings_changed)
async def run(self) -> None:
"""Run and maintain a connection to satellite."""
@ -135,6 +136,12 @@ class WyomingSatellite:
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
def _audio_settings_changed(self) -> None:
"""Run when device audio settings."""
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
async def _run_once(self) -> None:
"""Run pipelines until an error occurs."""
self.device.set_is_active(False)
@ -227,6 +234,11 @@ class WyomingSatellite:
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(
noise_suppression_level=self.device.noise_suppression_level,
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
),
)
)

View File

@ -1,12 +1,15 @@
"""Select entities for VoIP integration."""
"""Select entities for Wyoming integration."""
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Final
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.components.select import SelectEntity, SelectEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.helpers import restore_state
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
@ -16,19 +19,34 @@ from .entity import WyomingSatelliteEntity
if TYPE_CHECKING:
from .models import DomainDataItem
_NOISE_SUPPRESSION_LEVEL: Final = {
"off": 0,
"low": 1,
"medium": 2,
"high": 3,
"max": 4,
}
_DEFAULT_NOISE_SUPPRESSION_LEVEL: Final = "off"
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP switch entities."""
"""Set up Wyoming select entities."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)])
device = item.satellite.device
async_add_entities(
[
WyomingSatellitePipelineSelect(hass, device),
WyomingSatelliteNoiseSuppressionLevelSelect(device),
]
)
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
@ -45,3 +63,32 @@ class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelec
"""Select an option."""
await super().async_select_option(option)
self.device.set_pipeline_name(option)
class WyomingSatelliteNoiseSuppressionLevelSelect(
WyomingSatelliteEntity, SelectEntity, restore_state.RestoreEntity
):
"""Entity to represent noise suppression level setting."""
entity_description = SelectEntityDescription(
key="noise_suppression_level",
translation_key="noise_suppression_level",
entity_category=EntityCategory.CONFIG,
)
_attr_should_poll = False
_attr_current_option = _DEFAULT_NOISE_SUPPRESSION_LEVEL
_attr_options = list(_NOISE_SUPPRESSION_LEVEL.keys())
async def async_added_to_hass(self) -> None:
"""When entity is added to Home Assistant."""
await super().async_added_to_hass()
state = await self.async_get_last_state()
if state is not None and state.state in self.options:
self._attr_current_option = state.state
async def async_select_option(self, option: str) -> None:
"""Select an option."""
self._attr_current_option = option
self.async_write_ha_state()
self._device.set_noise_suppression_level(_NOISE_SUPPRESSION_LEVEL[option])

View File

@ -37,14 +37,29 @@
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
}
},
"noise_suppression": {
"name": "Noise suppression"
"noise_suppression_level": {
"name": "Noise suppression level",
"state": {
"off": "Off",
"low": "Low",
"medium": "Medium",
"high": "High",
"max": "Max"
}
}
},
"switch": {
"satellite_enabled": {
"name": "Satellite enabled"
}
},
"number": {
"auto_gain": {
"name": "Auto gain"
},
"volume_multiplier": {
"name": "Mic volume"
}
}
}
}

View File

@ -1,5 +1,6 @@
"""Tests for the Wyoming integration."""
import asyncio
from unittest.mock import patch
from wyoming.event import Event
from wyoming.info import (
@ -15,6 +16,10 @@ from wyoming.info import (
WakeProgram,
)
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.core import HomeAssistant
TEST_ATTR = Attribution(name="Test", url="http://www.test.com")
STT_INFO = Info(
asr=[
@ -124,3 +129,19 @@ class MockAsyncTcpClient:
self.host = host
self.port = port
return self
async def reload_satellite(
hass: HomeAssistant, config_entry_id: str
) -> SatelliteDevice:
"""Reload config entry with satellite info and returns new device."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
) as _run_mock:
# _run_mock: satellite task does not actually run
await hass.config_entries.async_reload(config_entry_id)
return hass.data[DOMAIN][config_entry_id].satellite.device

View File

@ -16,6 +16,12 @@ from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
"""Mock the TTS cache dir with empty dir."""
return mock_tts_cache_dir
@pytest.fixture(autouse=True)
async def init_components(hass: HomeAssistant):
"""Set up required components."""

View File

@ -4,6 +4,8 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from . import reload_satellite
async def test_assist_in_progress(
hass: HomeAssistant,
@ -26,7 +28,8 @@ async def test_assist_in_progress(
assert state.state == STATE_ON
assert satellite_device.is_active
satellite_device.set_is_active(False)
# test restore does *not* happen
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(assist_in_progress_id)
assert state is not None

View File

@ -0,0 +1,102 @@
"""Test Wyoming number."""
from unittest.mock import patch
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from . import reload_satellite
async def test_auto_gain_number(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test automatic gain control number."""
agc_entity_id = satellite_device.get_auto_gain_entity_id(hass)
assert agc_entity_id
state = hass.states.get(agc_entity_id)
assert state is not None
assert int(state.state) == 0
assert satellite_device.auto_gain == 0
# Change setting
with patch.object(satellite_device, "set_auto_gain") as mock_agc_changed:
await hass.services.async_call(
"number",
"set_value",
{"entity_id": agc_entity_id, "value": 31},
blocking=True,
)
state = hass.states.get(agc_entity_id)
assert state is not None
assert int(state.state) == 31
# set function should have been called
mock_agc_changed.assert_called_once_with(31)
# test restore
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(agc_entity_id)
assert state is not None
assert int(state.state) == 31
await hass.services.async_call(
"number",
"set_value",
{"entity_id": agc_entity_id, "value": 15},
blocking=True,
)
assert satellite_device.auto_gain == 15
async def test_volume_multiplier_number(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test volume multiplier number."""
vm_entity_id = satellite_device.get_volume_multiplier_entity_id(hass)
assert vm_entity_id
state = hass.states.get(vm_entity_id)
assert state is not None
assert float(state.state) == 1.0
assert satellite_device.volume_multiplier == 1.0
# Change setting
with patch.object(satellite_device, "set_volume_multiplier") as mock_vm_changed:
await hass.services.async_call(
"number",
"set_value",
{"entity_id": vm_entity_id, "value": 2.0},
blocking=True,
)
state = hass.states.get(vm_entity_id)
assert state is not None
assert float(state.state) == 2.0
# set function should have been called
mock_vm_changed.assert_called_once_with(2.0)
# test restore
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(vm_entity_id)
assert state is not None
assert float(state.state) == 2.0
await hass.services.async_call(
"number",
"set_value",
{"entity_id": vm_entity_id, "value": 0.5},
blocking=True,
)
assert float(satellite_device.volume_multiplier) == 0.5

View File

@ -9,6 +9,8 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import reload_satellite
async def test_pipeline_select(
hass: HomeAssistant,
@ -61,9 +63,16 @@ async def test_pipeline_select(
assert state is not None
assert state.state == "Test 1"
# async_pipeline_changed should have been called
# set function should have been called
mock_pipeline_changed.assert_called_once_with("Test 1")
# test restore
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(pipeline_entity_id)
assert state is not None
assert state.state == "Test 1"
# Change back and check update listener
pipeline_listener = Mock()
satellite_device.set_pipeline_listener(pipeline_listener)
@ -81,3 +90,52 @@ async def test_pipeline_select(
# listener should have been called
pipeline_listener.assert_called_once()
async def test_noise_suppression_level_select(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test noise suppression level select."""
nsl_entity_id = satellite_device.get_noise_suppression_level_entity_id(hass)
assert nsl_entity_id
state = hass.states.get(nsl_entity_id)
assert state is not None
assert state.state == "off"
assert satellite_device.noise_suppression_level == 0
# Change setting
with patch.object(
satellite_device, "set_noise_suppression_level"
) as mock_nsl_changed:
await hass.services.async_call(
"select",
"select_option",
{"entity_id": nsl_entity_id, "option": "max"},
blocking=True,
)
state = hass.states.get(nsl_entity_id)
assert state is not None
assert state.state == "max"
# set function should have been called
mock_nsl_changed.assert_called_once_with(4)
# test restore
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(nsl_entity_id)
assert state is not None
assert state.state == "max"
await hass.services.async_call(
"select",
"select_option",
{"entity_id": nsl_entity_id, "option": "medium"},
blocking=True,
)
assert satellite_device.noise_suppression_level == 2

View File

@ -4,6 +4,8 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from . import reload_satellite
async def test_satellite_enabled(
hass: HomeAssistant,
@ -30,3 +32,10 @@ async def test_satellite_enabled(
assert state is not None
assert state.state == STATE_OFF
assert not satellite_device.is_enabled
# test restore
satellite_device = await reload_satellite(hass, satellite_config_entry.entry_id)
state = hass.states.get(satellite_enabled_id)
assert state is not None
assert state.state == STATE_OFF

View File

@ -16,12 +16,6 @@ from homeassistant.helpers.entity_component import DATA_INSTANCES
from . import MockAsyncTcpClient
@pytest.fixture(autouse=True)
def mock_tts_cache_dir_autouse(mock_tts_cache_dir):
"""Mock the TTS cache dir with empty dir."""
return mock_tts_cache_dir
async def test_support(hass: HomeAssistant, init_wyoming_tts) -> None:
"""Test supported properties."""
state = hass.states.get("tts.test_tts")