Add Wyoming satellite (#104759)
* First draft of Wyoming satellite * Set up homeassistant in tests * Move satellite * Add devices with binary sensor and select * Add more events * Add satellite enabled switch * Fix mistake * Only set up necessary platforms for satellites * Lots of fixes * Add tests * Use config entry id as satellite id * Initial satellite test * Add satellite pipeline test * More tests * More satellite tests * Only support single device per config entry * Address comments * Make a copy of platformspull/105135/head
parent
db6b804298
commit
5a49e1dd5c
|
@ -4,17 +4,26 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.exceptions import ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryNotReady
|
||||||
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
|
||||||
from .const import ATTR_SPEAKER, DOMAIN
|
from .const import ATTR_SPEAKER, DOMAIN
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
|
from .devices import SatelliteDevice
|
||||||
|
from .models import DomainDataItem
|
||||||
|
from .satellite import WyomingSatellite
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH]
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ATTR_SPEAKER",
|
"ATTR_SPEAKER",
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
"async_setup_entry",
|
||||||
|
"async_unload_entry",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -25,24 +34,72 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
if service is None:
|
if service is None:
|
||||||
raise ConfigEntryNotReady("Unable to connect")
|
raise ConfigEntryNotReady("Unable to connect")
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service
|
item = DomainDataItem(service=service)
|
||||||
|
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item
|
||||||
|
|
||||||
await hass.config_entries.async_forward_entry_setups(
|
await hass.config_entries.async_forward_entry_setups(entry, service.platforms)
|
||||||
entry,
|
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||||
service.platforms,
|
|
||||||
)
|
if (satellite_info := service.info.satellite) is not None:
|
||||||
|
# Create satellite device, etc.
|
||||||
|
item.satellite = _make_satellite(hass, entry, service)
|
||||||
|
|
||||||
|
# Set up satellite sensors, switches, etc.
|
||||||
|
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
|
||||||
|
|
||||||
|
# Start satellite communication
|
||||||
|
entry.async_create_background_task(
|
||||||
|
hass,
|
||||||
|
item.satellite.run(),
|
||||||
|
f"Satellite {satellite_info.name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
entry.async_on_unload(item.satellite.stop)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _make_satellite(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||||
|
) -> WyomingSatellite:
|
||||||
|
"""Create Wyoming satellite/device from config entry and Wyoming service."""
|
||||||
|
satellite_info = service.info.satellite
|
||||||
|
assert satellite_info is not None
|
||||||
|
|
||||||
|
dev_reg = dr.async_get(hass)
|
||||||
|
|
||||||
|
# Use config entry id since only one satellite per entry is supported
|
||||||
|
satellite_id = config_entry.entry_id
|
||||||
|
|
||||||
|
device = dev_reg.async_get_or_create(
|
||||||
|
config_entry_id=config_entry.entry_id,
|
||||||
|
identifiers={(DOMAIN, satellite_id)},
|
||||||
|
name=satellite_info.name,
|
||||||
|
suggested_area=satellite_info.area,
|
||||||
|
)
|
||||||
|
|
||||||
|
satellite_device = SatelliteDevice(
|
||||||
|
satellite_id=satellite_id,
|
||||||
|
device_id=device.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WyomingSatellite(hass, service, satellite_device)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||||
|
"""Handle options update."""
|
||||||
|
await hass.config_entries.async_reload(entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||||
"""Unload Wyoming."""
|
"""Unload Wyoming."""
|
||||||
service: WyomingService = hass.data[DOMAIN][entry.entry_id]
|
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
|
||||||
|
|
||||||
unload_ok = await hass.config_entries.async_unload_platforms(
|
platforms = list(item.service.platforms)
|
||||||
entry,
|
if item.satellite is not None:
|
||||||
service.platforms,
|
platforms += SATELLITE_PLATFORMS
|
||||||
)
|
|
||||||
|
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||||
if unload_ok:
|
if unload_ok:
|
||||||
del hass.data[DOMAIN][entry.entry_id]
|
del hass.data[DOMAIN][entry.entry_id]
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,55 @@
|
||||||
|
"""Binary sensor for Wyoming."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from homeassistant.components.binary_sensor import (
|
||||||
|
BinarySensorEntity,
|
||||||
|
BinarySensorEntityDescription,
|
||||||
|
)
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import WyomingSatelliteEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up binary sensor entities."""
|
||||||
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
|
|
||||||
|
# Setup is only forwarded for satellites
|
||||||
|
assert item.satellite is not None
|
||||||
|
|
||||||
|
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
|
||||||
|
"""Entity to represent Assist is in progress for satellite."""
|
||||||
|
|
||||||
|
entity_description = BinarySensorEntityDescription(
|
||||||
|
key="assist_in_progress",
|
||||||
|
translation_key="assist_in_progress",
|
||||||
|
)
|
||||||
|
_attr_is_on = False
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Call when entity about to be added to hass."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
|
self._device.set_is_active_listener(self._is_active_changed)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _is_active_changed(self) -> None:
|
||||||
|
"""Call when active state changed."""
|
||||||
|
self._attr_is_on = self._device.is_active
|
||||||
|
self.async_write_ha_state()
|
|
@ -1,19 +1,22 @@
|
||||||
"""Config flow for Wyoming integration."""
|
"""Config flow for Wyoming integration."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.hassio import HassioServiceInfo
|
from homeassistant.components import hassio, zeroconf
|
||||||
from homeassistant.const import CONF_HOST, CONF_PORT
|
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT
|
||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger()
|
||||||
|
|
||||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
vol.Required(CONF_HOST): str,
|
vol.Required(CONF_HOST): str,
|
||||||
|
@ -27,7 +30,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
|
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
|
|
||||||
_hassio_discovery: HassioServiceInfo
|
_hassio_discovery: hassio.HassioServiceInfo
|
||||||
|
_service: WyomingService | None = None
|
||||||
|
_name: str | None = None
|
||||||
|
|
||||||
async def async_step_user(
|
async def async_step_user(
|
||||||
self, user_input: dict[str, Any] | None = None
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
@ -50,27 +55,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
errors={"base": "cannot_connect"},
|
errors={"base": "cannot_connect"},
|
||||||
)
|
)
|
||||||
|
|
||||||
# ASR = automated speech recognition (speech-to-text)
|
if name := service.get_name():
|
||||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
return self.async_create_entry(title=name, data=user_input)
|
||||||
|
|
||||||
# TTS = text-to-speech
|
return self.async_abort(reason="no_services")
|
||||||
tts_installed = [tts for tts in service.info.tts if tts.installed]
|
|
||||||
|
|
||||||
# wake-word-detection
|
async def async_step_hassio(
|
||||||
wake_installed = [wake for wake in service.info.wake if wake.installed]
|
self, discovery_info: hassio.HassioServiceInfo
|
||||||
|
) -> FlowResult:
|
||||||
if asr_installed:
|
|
||||||
name = asr_installed[0].name
|
|
||||||
elif tts_installed:
|
|
||||||
name = tts_installed[0].name
|
|
||||||
elif wake_installed:
|
|
||||||
name = wake_installed[0].name
|
|
||||||
else:
|
|
||||||
return self.async_abort(reason="no_services")
|
|
||||||
|
|
||||||
return self.async_create_entry(title=name, data=user_input)
|
|
||||||
|
|
||||||
async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult:
|
|
||||||
"""Handle Supervisor add-on discovery."""
|
"""Handle Supervisor add-on discovery."""
|
||||||
await self.async_set_unique_id(discovery_info.uuid)
|
await self.async_set_unique_id(discovery_info.uuid)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
|
@ -93,11 +85,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
uri = urlparse(self._hassio_discovery.config["uri"])
|
uri = urlparse(self._hassio_discovery.config["uri"])
|
||||||
if service := await WyomingService.create(uri.hostname, uri.port):
|
if service := await WyomingService.create(uri.hostname, uri.port):
|
||||||
if (
|
if not service.has_services():
|
||||||
not any(asr for asr in service.info.asr if asr.installed)
|
|
||||||
and not any(tts for tts in service.info.tts if tts.installed)
|
|
||||||
and not any(wake for wake in service.info.wake if wake.installed)
|
|
||||||
):
|
|
||||||
return self.async_abort(reason="no_services")
|
return self.async_abort(reason="no_services")
|
||||||
|
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
|
@ -112,3 +100,52 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
description_placeholders={"addon": self._hassio_discovery.name},
|
description_placeholders={"addon": self._hassio_discovery.name},
|
||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def async_step_zeroconf(
|
||||||
|
self, discovery_info: zeroconf.ZeroconfServiceInfo
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle zeroconf discovery."""
|
||||||
|
_LOGGER.debug("Discovery info: %s", discovery_info)
|
||||||
|
if discovery_info.port is None:
|
||||||
|
return self.async_abort(reason="no_port")
|
||||||
|
|
||||||
|
service = await WyomingService.create(discovery_info.host, discovery_info.port)
|
||||||
|
if (service is None) or (not (name := service.get_name())):
|
||||||
|
# No supported services
|
||||||
|
return self.async_abort(reason="no_services")
|
||||||
|
|
||||||
|
self._name = name
|
||||||
|
|
||||||
|
# Use zeroconf name + service name as unique id.
|
||||||
|
# The satellite will use its own MAC as the zeroconf name by default.
|
||||||
|
unique_id = f"{discovery_info.name}_{self._name}"
|
||||||
|
await self.async_set_unique_id(unique_id)
|
||||||
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
|
self.context[CONF_NAME] = self._name
|
||||||
|
self.context["title_placeholders"] = {"name": self._name}
|
||||||
|
|
||||||
|
self._service = service
|
||||||
|
return await self.async_step_zeroconf_confirm()
|
||||||
|
|
||||||
|
async def async_step_zeroconf_confirm(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle a flow initiated by zeroconf."""
|
||||||
|
assert self._service is not None
|
||||||
|
assert self._name is not None
|
||||||
|
|
||||||
|
if user_input is None:
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="zeroconf_confirm",
|
||||||
|
description_placeholders={"name": self._name},
|
||||||
|
errors={},
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_create_entry(
|
||||||
|
title=self._name,
|
||||||
|
data={
|
||||||
|
CONF_HOST: self._service.host,
|
||||||
|
CONF_PORT: self._service.port,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from wyoming.client import AsyncTcpClient
|
from wyoming.client import AsyncTcpClient
|
||||||
from wyoming.info import Describe, Info
|
from wyoming.info import Describe, Info, Satellite
|
||||||
|
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
|
|
||||||
|
@ -32,6 +32,43 @@ class WyomingService:
|
||||||
platforms.append(Platform.WAKE_WORD)
|
platforms.append(Platform.WAKE_WORD)
|
||||||
self.platforms = platforms
|
self.platforms = platforms
|
||||||
|
|
||||||
|
def has_services(self) -> bool:
|
||||||
|
"""Return True if services are installed that Home Assistant can use."""
|
||||||
|
return (
|
||||||
|
any(asr for asr in self.info.asr if asr.installed)
|
||||||
|
or any(tts for tts in self.info.tts if tts.installed)
|
||||||
|
or any(wake for wake in self.info.wake if wake.installed)
|
||||||
|
or ((self.info.satellite is not None) and self.info.satellite.installed)
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_name(self) -> str | None:
|
||||||
|
"""Return name of first installed usable service."""
|
||||||
|
# ASR = automated speech recognition (speech-to-text)
|
||||||
|
asr_installed = [asr for asr in self.info.asr if asr.installed]
|
||||||
|
if asr_installed:
|
||||||
|
return asr_installed[0].name
|
||||||
|
|
||||||
|
# TTS = text-to-speech
|
||||||
|
tts_installed = [tts for tts in self.info.tts if tts.installed]
|
||||||
|
if tts_installed:
|
||||||
|
return tts_installed[0].name
|
||||||
|
|
||||||
|
# wake-word-detection
|
||||||
|
wake_installed = [wake for wake in self.info.wake if wake.installed]
|
||||||
|
if wake_installed:
|
||||||
|
return wake_installed[0].name
|
||||||
|
|
||||||
|
# satellite
|
||||||
|
satellite_installed: Satellite | None = None
|
||||||
|
|
||||||
|
if (self.info.satellite is not None) and self.info.satellite.installed:
|
||||||
|
satellite_installed = self.info.satellite
|
||||||
|
|
||||||
|
if satellite_installed:
|
||||||
|
return satellite_installed.name
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create(cls, host: str, port: int) -> WyomingService | None:
|
async def create(cls, host: str, port: int) -> WyomingService | None:
|
||||||
"""Create a Wyoming service."""
|
"""Create a Wyoming service."""
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
"""Class to manage satellite devices."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SatelliteDevice:
|
||||||
|
"""Class to store device."""
|
||||||
|
|
||||||
|
satellite_id: str
|
||||||
|
device_id: str
|
||||||
|
is_active: bool = False
|
||||||
|
is_enabled: bool = True
|
||||||
|
pipeline_name: str | None = None
|
||||||
|
|
||||||
|
_is_active_listener: Callable[[], None] | None = None
|
||||||
|
_is_enabled_listener: Callable[[], None] | None = None
|
||||||
|
_pipeline_listener: Callable[[], None] | None = None
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_is_active(self, active: bool) -> None:
|
||||||
|
"""Set active state."""
|
||||||
|
if active != self.is_active:
|
||||||
|
self.is_active = active
|
||||||
|
if self._is_active_listener is not None:
|
||||||
|
self._is_active_listener()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_is_enabled(self, enabled: bool) -> None:
|
||||||
|
"""Set enabled state."""
|
||||||
|
if enabled != self.is_enabled:
|
||||||
|
self.is_enabled = enabled
|
||||||
|
if self._is_enabled_listener is not None:
|
||||||
|
self._is_enabled_listener()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_pipeline_name(self, pipeline_name: str) -> None:
|
||||||
|
"""Inform listeners that pipeline selection has changed."""
|
||||||
|
if pipeline_name != self.pipeline_name:
|
||||||
|
self.pipeline_name = pipeline_name
|
||||||
|
if self._pipeline_listener is not None:
|
||||||
|
self._pipeline_listener()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None:
|
||||||
|
"""Listen for updates to is_active."""
|
||||||
|
self._is_active_listener = is_active_listener
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_is_enabled_listener(self, is_enabled_listener: Callable[[], None]) -> None:
|
||||||
|
"""Listen for updates to is_enabled."""
|
||||||
|
self._is_enabled_listener = is_enabled_listener
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def set_pipeline_listener(self, pipeline_listener: Callable[[], None]) -> None:
|
||||||
|
"""Listen for updates to pipeline."""
|
||||||
|
self._pipeline_listener = pipeline_listener
|
||||||
|
|
||||||
|
def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return entity id for assist in progress binary sensor."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
"binary_sensor", DOMAIN, f"{self.satellite_id}-assist_in_progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_satellite_enabled_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return entity id for satellite enabled switch."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
"switch", DOMAIN, f"{self.satellite_id}-satellite_enabled"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||||
|
"""Return entity id for pipeline select."""
|
||||||
|
ent_reg = er.async_get(hass)
|
||||||
|
return ent_reg.async_get_entity_id(
|
||||||
|
"select", DOMAIN, f"{self.satellite_id}-pipeline"
|
||||||
|
)
|
|
@ -0,0 +1,24 @@
|
||||||
|
"""Wyoming entities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.helpers import entity
|
||||||
|
from homeassistant.helpers.device_registry import DeviceInfo
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .satellite import SatelliteDevice
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingSatelliteEntity(entity.Entity):
|
||||||
|
"""Wyoming satellite entity."""
|
||||||
|
|
||||||
|
_attr_has_entity_name = True
|
||||||
|
_attr_should_poll = False
|
||||||
|
|
||||||
|
def __init__(self, device: SatelliteDevice) -> None:
|
||||||
|
"""Initialize entity."""
|
||||||
|
self._device = device
|
||||||
|
self._attr_unique_id = f"{device.satellite_id}-{self.entity_description.key}"
|
||||||
|
self._attr_device_info = DeviceInfo(
|
||||||
|
identifiers={(DOMAIN, device.satellite_id)},
|
||||||
|
)
|
|
@ -3,7 +3,9 @@
|
||||||
"name": "Wyoming Protocol",
|
"name": "Wyoming Protocol",
|
||||||
"codeowners": ["@balloob", "@synesthesiam"],
|
"codeowners": ["@balloob", "@synesthesiam"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
|
"dependencies": ["assist_pipeline"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||||
"iot_class": "local_push",
|
"iot_class": "local_push",
|
||||||
"requirements": ["wyoming==1.2.0"]
|
"requirements": ["wyoming==1.3.0"],
|
||||||
|
"zeroconf": ["_wyoming._tcp.local."]
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
"""Models for wyoming."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from .data import WyomingService
|
||||||
|
from .satellite import WyomingSatellite
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DomainDataItem:
|
||||||
|
"""Domain data item."""
|
||||||
|
|
||||||
|
service: WyomingService
|
||||||
|
satellite: WyomingSatellite | None = None
|
|
@ -0,0 +1,380 @@
|
||||||
|
"""Support for Wyoming satellite services."""
|
||||||
|
import asyncio
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from typing import Final
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from wyoming.asr import Transcribe, Transcript
|
||||||
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
||||||
|
from wyoming.client import AsyncTcpClient
|
||||||
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||||
|
from wyoming.satellite import RunSatellite
|
||||||
|
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||||
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||||
|
from wyoming.wake import Detect, Detection
|
||||||
|
|
||||||
|
from homeassistant.components import assist_pipeline, stt, tts
|
||||||
|
from homeassistant.components.assist_pipeline import select as pipeline_select
|
||||||
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .data import WyomingService
|
||||||
|
from .devices import SatelliteDevice
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger()
|
||||||
|
|
||||||
|
_SAMPLES_PER_CHUNK: Final = 1024
|
||||||
|
_RECONNECT_SECONDS: Final = 10
|
||||||
|
_RESTART_SECONDS: Final = 3
|
||||||
|
|
||||||
|
# Wyoming stage -> Assist stage
|
||||||
|
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||||
|
PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
|
||||||
|
PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
|
||||||
|
PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
|
||||||
|
PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingSatellite:
|
||||||
|
"""Remove voice satellite running the Wyoming protocol."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice
|
||||||
|
) -> None:
|
||||||
|
"""Initialize satellite."""
|
||||||
|
self.hass = hass
|
||||||
|
self.service = service
|
||||||
|
self.device = device
|
||||||
|
self.is_enabled = True
|
||||||
|
self.is_running = True
|
||||||
|
|
||||||
|
self._client: AsyncTcpClient | None = None
|
||||||
|
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
|
||||||
|
self._is_pipeline_running = False
|
||||||
|
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||||
|
self._pipeline_id: str | None = None
|
||||||
|
self._enabled_changed_event = asyncio.Event()
|
||||||
|
|
||||||
|
self.device.set_is_enabled_listener(self._enabled_changed)
|
||||||
|
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
"""Run and maintain a connection to satellite."""
|
||||||
|
_LOGGER.debug("Running satellite task")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
# Check if satellite has been disabled
|
||||||
|
if not self.device.is_enabled:
|
||||||
|
await self.on_disabled()
|
||||||
|
if not self.is_running:
|
||||||
|
# Satellite was stopped while waiting to be enabled
|
||||||
|
break
|
||||||
|
|
||||||
|
# Connect and run pipeline loop
|
||||||
|
await self._run_once()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
await self.on_restart()
|
||||||
|
finally:
|
||||||
|
# Ensure sensor is off
|
||||||
|
self.device.set_is_active(False)
|
||||||
|
|
||||||
|
await self.on_stopped()
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
"""Signal satellite task to stop running."""
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# Unblock waiting for enabled
|
||||||
|
self._enabled_changed_event.set()
|
||||||
|
|
||||||
|
async def on_restart(self) -> None:
|
||||||
|
"""Block until pipeline loop will be restarted."""
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Unexpected error running satellite. Restarting in %s second(s)",
|
||||||
|
_RECONNECT_SECONDS,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_RESTART_SECONDS)
|
||||||
|
|
||||||
|
async def on_reconnect(self) -> None:
|
||||||
|
"""Block until a reconnection attempt should be made."""
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Failed to connect to satellite. Reconnecting in %s second(s)",
|
||||||
|
_RECONNECT_SECONDS,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(_RECONNECT_SECONDS)
|
||||||
|
|
||||||
|
async def on_disabled(self) -> None:
|
||||||
|
"""Block until device may be enabled again."""
|
||||||
|
await self._enabled_changed_event.wait()
|
||||||
|
|
||||||
|
async def on_stopped(self) -> None:
|
||||||
|
"""Run when run() has fully stopped."""
|
||||||
|
_LOGGER.debug("Satellite task stopped")
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _enabled_changed(self) -> None:
|
||||||
|
"""Run when device enabled status changes."""
|
||||||
|
|
||||||
|
if not self.device.is_enabled:
|
||||||
|
# Cancel any running pipeline
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
|
||||||
|
self._enabled_changed_event.set()
|
||||||
|
|
||||||
|
def _pipeline_changed(self) -> None:
|
||||||
|
"""Run when device pipeline changes."""
|
||||||
|
|
||||||
|
# Cancel any running pipeline
|
||||||
|
self._audio_queue.put_nowait(None)
|
||||||
|
|
||||||
|
async def _run_once(self) -> None:
|
||||||
|
"""Run pipelines until an error occurs."""
|
||||||
|
self.device.set_is_active(False)
|
||||||
|
|
||||||
|
while self.is_running and self.is_enabled:
|
||||||
|
try:
|
||||||
|
await self._connect()
|
||||||
|
break
|
||||||
|
except ConnectionError:
|
||||||
|
await self.on_reconnect()
|
||||||
|
|
||||||
|
assert self._client is not None
|
||||||
|
_LOGGER.debug("Connected to satellite")
|
||||||
|
|
||||||
|
if (not self.is_running) or (not self.is_enabled):
|
||||||
|
# Run was cancelled or satellite was disabled during connection
|
||||||
|
return
|
||||||
|
|
||||||
|
# Tell satellite that we're ready
|
||||||
|
await self._client.write_event(RunSatellite().event())
|
||||||
|
|
||||||
|
# Wait until we get RunPipeline event
|
||||||
|
run_pipeline: RunPipeline | None = None
|
||||||
|
while self.is_running and self.is_enabled:
|
||||||
|
run_event = await self._client.read_event()
|
||||||
|
if run_event is None:
|
||||||
|
raise ConnectionResetError("Satellite disconnected")
|
||||||
|
|
||||||
|
if RunPipeline.is_type(run_event.type):
|
||||||
|
run_pipeline = RunPipeline.from_event(run_event)
|
||||||
|
break
|
||||||
|
|
||||||
|
_LOGGER.debug("Unexpected event from satellite: %s", run_event)
|
||||||
|
|
||||||
|
assert run_pipeline is not None
|
||||||
|
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||||
|
|
||||||
|
if (not self.is_running) or (not self.is_enabled):
|
||||||
|
# Run was cancelled or satellite was disabled while waiting for
|
||||||
|
# RunPipeline event.
|
||||||
|
return
|
||||||
|
|
||||||
|
start_stage = _STAGES.get(run_pipeline.start_stage)
|
||||||
|
end_stage = _STAGES.get(run_pipeline.end_stage)
|
||||||
|
|
||||||
|
if start_stage is None:
|
||||||
|
raise ValueError(f"Invalid start stage: {start_stage}")
|
||||||
|
|
||||||
|
if end_stage is None:
|
||||||
|
raise ValueError(f"Invalid end stage: {end_stage}")
|
||||||
|
|
||||||
|
# Each loop is a pipeline run
|
||||||
|
while self.is_running and self.is_enabled:
|
||||||
|
# Use select to get pipeline each time in case it's changed
|
||||||
|
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||||
|
self.hass,
|
||||||
|
DOMAIN,
|
||||||
|
self.device.satellite_id,
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||||
|
assert pipeline is not None
|
||||||
|
|
||||||
|
# We will push audio in through a queue
|
||||||
|
self._audio_queue = asyncio.Queue()
|
||||||
|
stt_stream = self._stt_stream()
|
||||||
|
|
||||||
|
# Start pipeline running
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Starting pipeline %s from %s to %s",
|
||||||
|
pipeline.name,
|
||||||
|
start_stage,
|
||||||
|
end_stage,
|
||||||
|
)
|
||||||
|
self._is_pipeline_running = True
|
||||||
|
_pipeline_task = asyncio.create_task(
|
||||||
|
assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
|
self.hass,
|
||||||
|
context=Context(),
|
||||||
|
event_callback=self._event_callback,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language=pipeline.language,
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=stt_stream,
|
||||||
|
start_stage=start_stage,
|
||||||
|
end_stage=end_stage,
|
||||||
|
tts_audio_output="wav",
|
||||||
|
pipeline_id=pipeline_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run until pipeline is complete or cancelled with an empty audio chunk
|
||||||
|
while self._is_pipeline_running:
|
||||||
|
client_event = await self._client.read_event()
|
||||||
|
if client_event is None:
|
||||||
|
raise ConnectionResetError("Satellite disconnected")
|
||||||
|
|
||||||
|
if AudioChunk.is_type(client_event.type):
|
||||||
|
# Microphone audio
|
||||||
|
chunk = AudioChunk.from_event(client_event)
|
||||||
|
chunk = self._chunk_converter.convert(chunk)
|
||||||
|
self._audio_queue.put_nowait(chunk.audio)
|
||||||
|
else:
|
||||||
|
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||||
|
|
||||||
|
_LOGGER.debug("Pipeline finished")
|
||||||
|
|
||||||
|
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||||
|
"""Translate pipeline events into Wyoming events."""
|
||||||
|
assert self._client is not None
|
||||||
|
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||||
|
# Pipeline run is complete
|
||||||
|
self._is_pipeline_running = False
|
||||||
|
self.device.set_is_active(False)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||||
|
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||||
|
# Wake word detection
|
||||||
|
self.device.set_is_active(True)
|
||||||
|
|
||||||
|
# Inform client of wake word detection
|
||||||
|
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||||
|
detection = Detection(
|
||||||
|
name=wake_word_output["wake_word_id"],
|
||||||
|
timestamp=wake_word_output.get("timestamp"),
|
||||||
|
)
|
||||||
|
self.hass.add_job(self._client.write_event(detection.event()))
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||||
|
# Speech-to-text
|
||||||
|
self.device.set_is_active(True)
|
||||||
|
|
||||||
|
if event.data:
|
||||||
|
self.hass.add_job(
|
||||||
|
self._client.write_event(
|
||||||
|
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||||
|
# User started speaking
|
||||||
|
if event.data:
|
||||||
|
self.hass.add_job(
|
||||||
|
self._client.write_event(
|
||||||
|
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||||
|
# User stopped speaking
|
||||||
|
if event.data:
|
||||||
|
self.hass.add_job(
|
||||||
|
self._client.write_event(
|
||||||
|
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||||
|
# Speech-to-text transcript
|
||||||
|
if event.data:
|
||||||
|
# Inform client of transript
|
||||||
|
stt_text = event.data["stt_output"]["text"]
|
||||||
|
self.hass.add_job(
|
||||||
|
self._client.write_event(Transcript(text=stt_text).event())
|
||||||
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||||
|
# Text-to-speech text
|
||||||
|
if event.data:
|
||||||
|
# Inform client of text
|
||||||
|
self.hass.add_job(
|
||||||
|
self._client.write_event(
|
||||||
|
Synthesize(
|
||||||
|
text=event.data["tts_input"],
|
||||||
|
voice=SynthesizeVoice(
|
||||||
|
name=event.data.get("voice"),
|
||||||
|
language=event.data.get("language"),
|
||||||
|
),
|
||||||
|
).event()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||||
|
# TTS stream
|
||||||
|
if event.data and (tts_output := event.data["tts_output"]):
|
||||||
|
media_id = tts_output["media_id"]
|
||||||
|
self.hass.add_job(self._stream_tts(media_id))
|
||||||
|
|
||||||
|
async def _connect(self) -> None:
|
||||||
|
"""Connect to satellite over TCP."""
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Connecting to satellite at %s:%s", self.service.host, self.service.port
|
||||||
|
)
|
||||||
|
self._client = AsyncTcpClient(self.service.host, self.service.port)
|
||||||
|
await self._client.connect()
|
||||||
|
|
||||||
|
async def _stream_tts(self, media_id: str) -> None:
|
||||||
|
"""Stream TTS WAV audio to satellite in chunks."""
|
||||||
|
assert self._client is not None
|
||||||
|
|
||||||
|
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
|
||||||
|
if extension != "wav":
|
||||||
|
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
|
||||||
|
|
||||||
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
|
sample_rate = wav_file.getframerate()
|
||||||
|
sample_width = wav_file.getsampwidth()
|
||||||
|
sample_channels = wav_file.getnchannels()
|
||||||
|
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||||
|
|
||||||
|
timestamp = 0
|
||||||
|
await self._client.write_event(
|
||||||
|
AudioStart(
|
||||||
|
rate=sample_rate,
|
||||||
|
width=sample_width,
|
||||||
|
channels=sample_channels,
|
||||||
|
timestamp=timestamp,
|
||||||
|
).event()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stream audio chunks
|
||||||
|
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||||
|
chunk = AudioChunk(
|
||||||
|
rate=sample_rate,
|
||||||
|
width=sample_width,
|
||||||
|
channels=sample_channels,
|
||||||
|
audio=audio_bytes,
|
||||||
|
timestamp=timestamp,
|
||||||
|
)
|
||||||
|
await self._client.write_event(chunk.event())
|
||||||
|
timestamp += chunk.seconds
|
||||||
|
|
||||||
|
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||||
|
_LOGGER.debug("TTS streaming complete")
|
||||||
|
|
||||||
|
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
|
||||||
|
"""Yield audio chunks from a queue."""
|
||||||
|
is_first_chunk = True
|
||||||
|
while chunk := await self._audio_queue.get():
|
||||||
|
if is_first_chunk:
|
||||||
|
is_first_chunk = False
|
||||||
|
_LOGGER.debug("Receiving audio from satellite")
|
||||||
|
|
||||||
|
yield chunk
|
|
@ -0,0 +1,47 @@
|
||||||
|
"""Select entities for VoIP integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .devices import SatelliteDevice
|
||||||
|
from .entity import WyomingSatelliteEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up VoIP switch entities."""
|
||||||
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
|
|
||||||
|
# Setup is only forwarded for satellites
|
||||||
|
assert item.satellite is not None
|
||||||
|
|
||||||
|
async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)])
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
|
||||||
|
"""Pipeline selector for Wyoming satellites."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
|
||||||
|
"""Initialize a pipeline selector."""
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
WyomingSatelliteEntity.__init__(self, device)
|
||||||
|
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
|
||||||
|
|
||||||
|
async def async_select_option(self, option: str) -> None:
|
||||||
|
"""Select an option."""
|
||||||
|
await super().async_select_option(option)
|
||||||
|
self.device.set_pipeline_name(option)
|
|
@ -9,6 +9,10 @@
|
||||||
},
|
},
|
||||||
"hassio_confirm": {
|
"hassio_confirm": {
|
||||||
"description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?"
|
"description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?"
|
||||||
|
},
|
||||||
|
"zeroconf_confirm": {
|
||||||
|
"description": "Do you want to configure Home Assistant to connect to the Wyoming service {name}?",
|
||||||
|
"title": "Discovered Wyoming service"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"error": {
|
"error": {
|
||||||
|
@ -16,7 +20,31 @@
|
||||||
},
|
},
|
||||||
"abort": {
|
"abort": {
|
||||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
|
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
|
||||||
"no_services": "No services found at endpoint"
|
"no_services": "No services found at endpoint",
|
||||||
|
"no_port": "No port for endpoint"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"entity": {
|
||||||
|
"binary_sensor": {
|
||||||
|
"assist_in_progress": {
|
||||||
|
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"select": {
|
||||||
|
"pipeline": {
|
||||||
|
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
|
||||||
|
"state": {
|
||||||
|
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"noise_suppression": {
|
||||||
|
"name": "Noise suppression"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"switch": {
|
||||||
|
"satellite_enabled": {
|
||||||
|
"name": "Satellite enabled"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,6 +14,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
|
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
from .error import WyomingError
|
from .error import WyomingError
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -24,10 +25,10 @@ async def async_setup_entry(
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Wyoming speech-to-text."""
|
"""Set up Wyoming speech-to-text."""
|
||||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
WyomingSttProvider(config_entry, service),
|
WyomingSttProvider(config_entry, item.service),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,65 @@
|
||||||
|
"""Wyoming switch entities."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_ON, EntityCategory
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import restore_state
|
||||||
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entity import WyomingSatelliteEntity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
async_add_entities: AddEntitiesCallback,
|
||||||
|
) -> None:
|
||||||
|
"""Set up VoIP switch entities."""
|
||||||
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
|
|
||||||
|
# Setup is only forwarded for satellites
|
||||||
|
assert item.satellite is not None
|
||||||
|
|
||||||
|
async_add_entities([WyomingSatelliteEnabledSwitch(item.satellite.device)])
|
||||||
|
|
||||||
|
|
||||||
|
class WyomingSatelliteEnabledSwitch(
|
||||||
|
WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity
|
||||||
|
):
|
||||||
|
"""Entity to represent if satellite is enabled."""
|
||||||
|
|
||||||
|
entity_description = SwitchEntityDescription(
|
||||||
|
key="satellite_enabled",
|
||||||
|
translation_key="satellite_enabled",
|
||||||
|
entity_category=EntityCategory.CONFIG,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_added_to_hass(self) -> None:
|
||||||
|
"""Call when entity about to be added to hass."""
|
||||||
|
await super().async_added_to_hass()
|
||||||
|
|
||||||
|
state = await self.async_get_last_state()
|
||||||
|
|
||||||
|
# Default to on
|
||||||
|
self._attr_is_on = (state is None) or (state.state == STATE_ON)
|
||||||
|
|
||||||
|
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||||
|
"""Turn on."""
|
||||||
|
self._attr_is_on = True
|
||||||
|
self.async_write_ha_state()
|
||||||
|
self._device.set_is_enabled(True)
|
||||||
|
|
||||||
|
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||||
|
"""Turn off."""
|
||||||
|
self._attr_is_on = False
|
||||||
|
self.async_write_ha_state()
|
||||||
|
self._device.set_is_enabled(False)
|
|
@ -16,6 +16,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from .const import ATTR_SPEAKER, DOMAIN
|
from .const import ATTR_SPEAKER, DOMAIN
|
||||||
from .data import WyomingService
|
from .data import WyomingService
|
||||||
from .error import WyomingError
|
from .error import WyomingError
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -26,10 +27,10 @@ async def async_setup_entry(
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Wyoming speech-to-text."""
|
"""Set up Wyoming speech-to-text."""
|
||||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
WyomingTtsProvider(config_entry, service),
|
WyomingTtsProvider(config_entry, item.service),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
from .data import WyomingService, load_wyoming_info
|
from .data import WyomingService, load_wyoming_info
|
||||||
from .error import WyomingError
|
from .error import WyomingError
|
||||||
|
from .models import DomainDataItem
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -25,10 +26,10 @@ async def async_setup_entry(
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up Wyoming wake-word-detection."""
|
"""Set up Wyoming wake-word-detection."""
|
||||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||||
async_add_entities(
|
async_add_entities(
|
||||||
[
|
[
|
||||||
WyomingWakeWordProvider(hass, config_entry, service),
|
WyomingWakeWordProvider(hass, config_entry, item.service),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -715,6 +715,11 @@ ZEROCONF = {
|
||||||
"domain": "wled",
|
"domain": "wled",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
|
"_wyoming._tcp.local.": [
|
||||||
|
{
|
||||||
|
"domain": "wyoming",
|
||||||
|
},
|
||||||
|
],
|
||||||
"_xbmc-jsonrpc-h._tcp.local.": [
|
"_xbmc-jsonrpc-h._tcp.local.": [
|
||||||
{
|
{
|
||||||
"domain": "kodi",
|
"domain": "kodi",
|
||||||
|
|
|
@ -2750,7 +2750,7 @@ wled==0.17.0
|
||||||
wolf-smartset==0.1.11
|
wolf-smartset==0.1.11
|
||||||
|
|
||||||
# homeassistant.components.wyoming
|
# homeassistant.components.wyoming
|
||||||
wyoming==1.2.0
|
wyoming==1.3.0
|
||||||
|
|
||||||
# homeassistant.components.xbox
|
# homeassistant.components.xbox
|
||||||
xbox-webapi==2.0.11
|
xbox-webapi==2.0.11
|
||||||
|
|
|
@ -2054,7 +2054,7 @@ wled==0.17.0
|
||||||
wolf-smartset==0.1.11
|
wolf-smartset==0.1.11
|
||||||
|
|
||||||
# homeassistant.components.wyoming
|
# homeassistant.components.wyoming
|
||||||
wyoming==1.2.0
|
wyoming==1.3.0
|
||||||
|
|
||||||
# homeassistant.components.xbox
|
# homeassistant.components.xbox
|
||||||
xbox-webapi==2.0.11
|
xbox-webapi==2.0.11
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
"""Tests for the Wyoming integration."""
|
"""Tests for the Wyoming integration."""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
from wyoming.event import Event
|
||||||
from wyoming.info import (
|
from wyoming.info import (
|
||||||
AsrModel,
|
AsrModel,
|
||||||
AsrProgram,
|
AsrProgram,
|
||||||
Attribution,
|
Attribution,
|
||||||
Info,
|
Info,
|
||||||
|
Satellite,
|
||||||
TtsProgram,
|
TtsProgram,
|
||||||
TtsVoice,
|
TtsVoice,
|
||||||
TtsVoiceSpeaker,
|
TtsVoiceSpeaker,
|
||||||
|
@ -72,24 +74,36 @@ WAKE_WORD_INFO = Info(
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
SATELLITE_INFO = Info(
|
||||||
|
satellite=Satellite(
|
||||||
|
name="Test Satellite",
|
||||||
|
description="Test Satellite",
|
||||||
|
installed=True,
|
||||||
|
attribution=TEST_ATTR,
|
||||||
|
area="Office",
|
||||||
|
)
|
||||||
|
)
|
||||||
EMPTY_INFO = Info()
|
EMPTY_INFO = Info()
|
||||||
|
|
||||||
|
|
||||||
class MockAsyncTcpClient:
|
class MockAsyncTcpClient:
|
||||||
"""Mock AsyncTcpClient."""
|
"""Mock AsyncTcpClient."""
|
||||||
|
|
||||||
def __init__(self, responses) -> None:
|
def __init__(self, responses: list[Event]) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
self.host = None
|
self.host: str | None = None
|
||||||
self.port = None
|
self.port: int | None = None
|
||||||
self.written = []
|
self.written: list[Event] = []
|
||||||
self.responses = responses
|
self.responses = responses
|
||||||
|
|
||||||
async def write_event(self, event):
|
async def connect(self) -> None:
|
||||||
|
"""Connect."""
|
||||||
|
|
||||||
|
async def write_event(self, event: Event):
|
||||||
"""Send."""
|
"""Send."""
|
||||||
self.written.append(event)
|
self.written.append(event)
|
||||||
|
|
||||||
async def read_event(self):
|
async def read_event(self) -> Event | None:
|
||||||
"""Receive."""
|
"""Receive."""
|
||||||
await asyncio.sleep(0) # force context switch
|
await asyncio.sleep(0) # force context switch
|
||||||
|
|
||||||
|
@ -105,7 +119,7 @@ class MockAsyncTcpClient:
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
"""Exit."""
|
"""Exit."""
|
||||||
|
|
||||||
def __call__(self, host, port):
|
def __call__(self, host: str, port: int):
|
||||||
"""Call."""
|
"""Call."""
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
|
|
|
@ -5,14 +5,23 @@ from unittest.mock import AsyncMock, patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import stt
|
from homeassistant.components import stt
|
||||||
|
from homeassistant.components.wyoming import DOMAIN
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def init_components(hass: HomeAssistant):
|
||||||
|
"""Set up required components."""
|
||||||
|
assert await async_setup_component(hass, "homeassistant", {})
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
||||||
"""Override async_setup_entry."""
|
"""Override async_setup_entry."""
|
||||||
|
@ -110,3 +119,39 @@ def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
|
||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||||
|
"""Create a config entry."""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain="wyoming",
|
||||||
|
data={
|
||||||
|
"host": "1.2.3.4",
|
||||||
|
"port": 1234,
|
||||||
|
},
|
||||||
|
title="Test Satellite",
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntry):
|
||||||
|
"""Initialize Wyoming satellite."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||||
|
) as _run_mock:
|
||||||
|
# _run_mock: satellite task does not actually run
|
||||||
|
await hass.config_entries.async_setup(satellite_config_entry.entry_id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def satellite_device(
|
||||||
|
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
|
||||||
|
) -> SatelliteDevice:
|
||||||
|
"""Get a satellite device fixture."""
|
||||||
|
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
|
||||||
|
|
|
@ -121,3 +121,45 @@
|
||||||
'version': 1,
|
'version': 1,
|
||||||
})
|
})
|
||||||
# ---
|
# ---
|
||||||
|
# name: test_zeroconf_discovery
|
||||||
|
FlowResultSnapshot({
|
||||||
|
'context': dict({
|
||||||
|
'name': 'Test Satellite',
|
||||||
|
'source': 'zeroconf',
|
||||||
|
'title_placeholders': dict({
|
||||||
|
'name': 'Test Satellite',
|
||||||
|
}),
|
||||||
|
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 12345,
|
||||||
|
}),
|
||||||
|
'description': None,
|
||||||
|
'description_placeholders': None,
|
||||||
|
'flow_id': <ANY>,
|
||||||
|
'handler': 'wyoming',
|
||||||
|
'options': dict({
|
||||||
|
}),
|
||||||
|
'result': ConfigEntrySnapshot({
|
||||||
|
'data': dict({
|
||||||
|
'host': '127.0.0.1',
|
||||||
|
'port': 12345,
|
||||||
|
}),
|
||||||
|
'disabled_by': None,
|
||||||
|
'domain': 'wyoming',
|
||||||
|
'entry_id': <ANY>,
|
||||||
|
'options': dict({
|
||||||
|
}),
|
||||||
|
'pref_disable_new_entities': False,
|
||||||
|
'pref_disable_polling': False,
|
||||||
|
'source': 'zeroconf',
|
||||||
|
'title': 'Test Satellite',
|
||||||
|
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
|
||||||
|
'version': 1,
|
||||||
|
}),
|
||||||
|
'title': 'Test Satellite',
|
||||||
|
'type': <FlowResultType.CREATE_ENTRY: 'create_entry'>,
|
||||||
|
'version': 1,
|
||||||
|
})
|
||||||
|
# ---
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
"""Test Wyoming binary sensor devices."""
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_OFF, STATE_ON
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_assist_in_progress(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
satellite_config_entry: ConfigEntry,
|
||||||
|
satellite_device: SatelliteDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test assist in progress."""
|
||||||
|
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||||
|
assert assist_in_progress_id
|
||||||
|
|
||||||
|
state = hass.states.get(assist_in_progress_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
assert not satellite_device.is_active
|
||||||
|
|
||||||
|
satellite_device.set_is_active(True)
|
||||||
|
|
||||||
|
state = hass.states.get(assist_in_progress_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_ON
|
||||||
|
assert satellite_device.is_active
|
||||||
|
|
||||||
|
satellite_device.set_is_active(False)
|
||||||
|
|
||||||
|
state = hass.states.get(assist_in_progress_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
assert not satellite_device.is_active
|
|
@ -1,4 +1,5 @@
|
||||||
"""Test the Wyoming config flow."""
|
"""Test the Wyoming config flow."""
|
||||||
|
from ipaddress import IPv4Address
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -8,10 +9,11 @@ from wyoming.info import Info
|
||||||
from homeassistant import config_entries
|
from homeassistant import config_entries
|
||||||
from homeassistant.components.hassio import HassioServiceInfo
|
from homeassistant.components.hassio import HassioServiceInfo
|
||||||
from homeassistant.components.wyoming.const import DOMAIN
|
from homeassistant.components.wyoming.const import DOMAIN
|
||||||
|
from homeassistant.components.zeroconf import ZeroconfServiceInfo
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
from . import EMPTY_INFO, STT_INFO, TTS_INFO
|
from . import EMPTY_INFO, SATELLITE_INFO, STT_INFO, TTS_INFO
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
@ -25,6 +27,16 @@ ADDON_DISCOVERY = HassioServiceInfo(
|
||||||
uuid="1234",
|
uuid="1234",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ZEROCONF_DISCOVERY = ZeroconfServiceInfo(
|
||||||
|
ip_address=IPv4Address("127.0.0.1"),
|
||||||
|
ip_addresses=[IPv4Address("127.0.0.1")],
|
||||||
|
port=12345,
|
||||||
|
hostname="localhost",
|
||||||
|
type="_wyoming._tcp.local.",
|
||||||
|
name="test_zeroconf_name._wyoming._tcp.local.",
|
||||||
|
properties={},
|
||||||
|
)
|
||||||
|
|
||||||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||||
|
|
||||||
|
|
||||||
|
@ -214,3 +226,70 @@ async def test_hassio_addon_no_supported_services(hass: HomeAssistant) -> None:
|
||||||
|
|
||||||
assert result2.get("type") == FlowResultType.ABORT
|
assert result2.get("type") == FlowResultType.ABORT
|
||||||
assert result2.get("reason") == "no_services"
|
assert result2.get("reason") == "no_services"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zeroconf_discovery(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_setup_entry: AsyncMock,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test config flow initiated by Supervisor."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
data=ZEROCONF_DISCOVERY,
|
||||||
|
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") == FlowResultType.FORM
|
||||||
|
assert result.get("step_id") == "zeroconf_confirm"
|
||||||
|
assert result.get("description_placeholders") == {
|
||||||
|
"name": SATELLITE_INFO.satellite.name
|
||||||
|
}
|
||||||
|
|
||||||
|
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||||
|
assert result2.get("type") == FlowResultType.CREATE_ENTRY
|
||||||
|
assert result2 == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zeroconf_discovery_no_port(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_setup_entry: AsyncMock,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test discovery when the zeroconf service does not have a port."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch.object(ZEROCONF_DISCOVERY, "port", None):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
data=ZEROCONF_DISCOVERY,
|
||||||
|
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") == FlowResultType.ABORT
|
||||||
|
assert result.get("reason") == "no_port"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zeroconf_discovery_no_services(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_setup_entry: AsyncMock,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test discovery when there are no supported services on the client."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=Info(),
|
||||||
|
):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
data=ZEROCONF_DISCOVERY,
|
||||||
|
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.get("type") == FlowResultType.ABORT
|
||||||
|
assert result.get("reason") == "no_services"
|
||||||
|
|
|
@ -3,13 +3,15 @@ from __future__ import annotations
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from homeassistant.components.wyoming.data import load_wyoming_info
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
|
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from . import STT_INFO, MockAsyncTcpClient
|
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||||
|
|
||||||
|
|
||||||
async def test_load_info(hass: HomeAssistant, snapshot) -> None:
|
async def test_load_info(hass: HomeAssistant, snapshot: SnapshotAssertion) -> None:
|
||||||
"""Test loading info."""
|
"""Test loading info."""
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
@ -38,3 +40,38 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert info is None
|
assert info is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_service_name(hass: HomeAssistant) -> None:
|
||||||
|
"""Test loading service info."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([STT_INFO.event()]),
|
||||||
|
):
|
||||||
|
service = await WyomingService.create("localhost", 1234)
|
||||||
|
assert service is not None
|
||||||
|
assert service.get_name() == STT_INFO.asr[0].name
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([TTS_INFO.event()]),
|
||||||
|
):
|
||||||
|
service = await WyomingService.create("localhost", 1234)
|
||||||
|
assert service is not None
|
||||||
|
assert service.get_name() == TTS_INFO.tts[0].name
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([WAKE_WORD_INFO.event()]),
|
||||||
|
):
|
||||||
|
service = await WyomingService.create("localhost", 1234)
|
||||||
|
assert service is not None
|
||||||
|
assert service.get_name() == WAKE_WORD_INFO.wake[0].name
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([SATELLITE_INFO.event()]),
|
||||||
|
):
|
||||||
|
service = await WyomingService.create("localhost", 1234)
|
||||||
|
assert service is not None
|
||||||
|
assert service.get_name() == SATELLITE_INFO.satellite.name
|
||||||
|
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""Test Wyoming devices."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||||
|
from homeassistant.components.wyoming import DOMAIN
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_OFF, STATE_ON
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.helpers import device_registry as dr
|
||||||
|
|
||||||
|
|
||||||
|
async def test_device_registry_info(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
satellite_device: SatelliteDevice,
|
||||||
|
satellite_config_entry: ConfigEntry,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test info in device registry."""
|
||||||
|
|
||||||
|
# Satellite uses config entry id since only one satellite per entry is
|
||||||
|
# supported.
|
||||||
|
device = device_registry.async_get_device(
|
||||||
|
identifiers={(DOMAIN, satellite_config_entry.entry_id)}
|
||||||
|
)
|
||||||
|
assert device is not None
|
||||||
|
assert device.name == "Test Satellite"
|
||||||
|
assert device.suggested_area == "Office"
|
||||||
|
|
||||||
|
# Check associated entities
|
||||||
|
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||||
|
assert assist_in_progress_id
|
||||||
|
assist_in_progress_state = hass.states.get(assist_in_progress_id)
|
||||||
|
assert assist_in_progress_state is not None
|
||||||
|
assert assist_in_progress_state.state == STATE_OFF
|
||||||
|
|
||||||
|
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||||
|
assert satellite_enabled_id
|
||||||
|
satellite_enabled_state = hass.states.get(satellite_enabled_id)
|
||||||
|
assert satellite_enabled_state is not None
|
||||||
|
assert satellite_enabled_state.state == STATE_ON
|
||||||
|
|
||||||
|
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||||
|
assert pipeline_entity_id
|
||||||
|
pipeline_state = hass.states.get(pipeline_entity_id)
|
||||||
|
assert pipeline_state is not None
|
||||||
|
assert pipeline_state.state == OPTION_PREFERRED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_remove_device_registry_entry(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
satellite_device: SatelliteDevice,
|
||||||
|
device_registry: dr.DeviceRegistry,
|
||||||
|
) -> None:
|
||||||
|
"""Test removing a device registry entry."""
|
||||||
|
|
||||||
|
# Check associated entities
|
||||||
|
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||||
|
assert assist_in_progress_id
|
||||||
|
assert hass.states.get(assist_in_progress_id) is not None
|
||||||
|
|
||||||
|
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||||
|
assert satellite_enabled_id
|
||||||
|
assert hass.states.get(satellite_enabled_id) is not None
|
||||||
|
|
||||||
|
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||||
|
assert pipeline_entity_id
|
||||||
|
assert hass.states.get(pipeline_entity_id) is not None
|
||||||
|
|
||||||
|
# Remove
|
||||||
|
device_registry.async_remove_device(satellite_device.device_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Everything should be gone
|
||||||
|
assert hass.states.get(assist_in_progress_id) is None
|
||||||
|
assert hass.states.get(satellite_enabled_id) is None
|
||||||
|
assert hass.states.get(pipeline_entity_id) is None
|
|
@ -0,0 +1,460 @@
|
||||||
|
"""Test Wyoming satellite."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
from unittest.mock import patch
|
||||||
|
import wave
|
||||||
|
|
||||||
|
from wyoming.asr import Transcribe, Transcript
|
||||||
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||||
|
from wyoming.event import Event
|
||||||
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||||
|
from wyoming.satellite import RunSatellite
|
||||||
|
from wyoming.tts import Synthesize
|
||||||
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||||
|
from wyoming.wake import Detect, Detection
|
||||||
|
|
||||||
|
from homeassistant.components import assist_pipeline, wyoming
|
||||||
|
from homeassistant.components.wyoming.data import WyomingService
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import SATELLITE_INFO, MockAsyncTcpClient
|
||||||
|
|
||||||
|
from tests.common import MockConfigEntry
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||||
|
"""Set up config entry for Wyoming satellite.
|
||||||
|
|
||||||
|
This is separated from the satellite_config_entry method in conftest.py so
|
||||||
|
we can patch functions before the satellite task is run during setup.
|
||||||
|
"""
|
||||||
|
entry = MockConfigEntry(
|
||||||
|
domain="wyoming",
|
||||||
|
data={
|
||||||
|
"host": "1.2.3.4",
|
||||||
|
"port": 1234,
|
||||||
|
},
|
||||||
|
title="Test Satellite",
|
||||||
|
)
|
||||||
|
entry.add_to_hass(hass)
|
||||||
|
await hass.config_entries.async_setup(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
return entry
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_wav() -> bytes:
|
||||||
|
"""Get bytes for test WAV file."""
|
||||||
|
with io.BytesIO() as wav_io:
|
||||||
|
with wave.open(wav_io, "wb") as wav_file:
|
||||||
|
wav_file.setframerate(22050)
|
||||||
|
wav_file.setsampwidth(2)
|
||||||
|
wav_file.setnchannels(1)
|
||||||
|
|
||||||
|
# Single frame
|
||||||
|
wav_file.writeframes(b"123")
|
||||||
|
|
||||||
|
return wav_io.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||||
|
"""Satellite AsyncTcpClient."""
|
||||||
|
|
||||||
|
def __init__(self, responses: list[Event]) -> None:
|
||||||
|
"""Initialize client."""
|
||||||
|
super().__init__(responses)
|
||||||
|
|
||||||
|
self.connect_event = asyncio.Event()
|
||||||
|
self.run_satellite_event = asyncio.Event()
|
||||||
|
self.detect_event = asyncio.Event()
|
||||||
|
|
||||||
|
self.detection_event = asyncio.Event()
|
||||||
|
self.detection: Detection | None = None
|
||||||
|
|
||||||
|
self.transcribe_event = asyncio.Event()
|
||||||
|
self.transcribe: Transcribe | None = None
|
||||||
|
|
||||||
|
self.voice_started_event = asyncio.Event()
|
||||||
|
self.voice_started: VoiceStarted | None = None
|
||||||
|
|
||||||
|
self.voice_stopped_event = asyncio.Event()
|
||||||
|
self.voice_stopped: VoiceStopped | None = None
|
||||||
|
|
||||||
|
self.transcript_event = asyncio.Event()
|
||||||
|
self.transcript: Transcript | None = None
|
||||||
|
|
||||||
|
self.synthesize_event = asyncio.Event()
|
||||||
|
self.synthesize: Synthesize | None = None
|
||||||
|
|
||||||
|
self.tts_audio_start_event = asyncio.Event()
|
||||||
|
self.tts_audio_chunk_event = asyncio.Event()
|
||||||
|
self.tts_audio_stop_event = asyncio.Event()
|
||||||
|
self.tts_audio_chunk: AudioChunk | None = None
|
||||||
|
|
||||||
|
self._mic_audio_chunk = AudioChunk(
|
||||||
|
rate=16000, width=2, channels=1, audio=b"chunk"
|
||||||
|
).event()
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""Connect."""
|
||||||
|
self.connect_event.set()
|
||||||
|
|
||||||
|
async def write_event(self, event: Event):
|
||||||
|
"""Send."""
|
||||||
|
if RunSatellite.is_type(event.type):
|
||||||
|
self.run_satellite_event.set()
|
||||||
|
elif Detect.is_type(event.type):
|
||||||
|
self.detect_event.set()
|
||||||
|
elif Detection.is_type(event.type):
|
||||||
|
self.detection = Detection.from_event(event)
|
||||||
|
self.detection_event.set()
|
||||||
|
elif Transcribe.is_type(event.type):
|
||||||
|
self.transcribe = Transcribe.from_event(event)
|
||||||
|
self.transcribe_event.set()
|
||||||
|
elif VoiceStarted.is_type(event.type):
|
||||||
|
self.voice_started = VoiceStarted.from_event(event)
|
||||||
|
self.voice_started_event.set()
|
||||||
|
elif VoiceStopped.is_type(event.type):
|
||||||
|
self.voice_stopped = VoiceStopped.from_event(event)
|
||||||
|
self.voice_stopped_event.set()
|
||||||
|
elif Transcript.is_type(event.type):
|
||||||
|
self.transcript = Transcript.from_event(event)
|
||||||
|
self.transcript_event.set()
|
||||||
|
elif Synthesize.is_type(event.type):
|
||||||
|
self.synthesize = Synthesize.from_event(event)
|
||||||
|
self.synthesize_event.set()
|
||||||
|
elif AudioStart.is_type(event.type):
|
||||||
|
self.tts_audio_start_event.set()
|
||||||
|
elif AudioChunk.is_type(event.type):
|
||||||
|
self.tts_audio_chunk = AudioChunk.from_event(event)
|
||||||
|
self.tts_audio_chunk_event.set()
|
||||||
|
elif AudioStop.is_type(event.type):
|
||||||
|
self.tts_audio_stop_event.set()
|
||||||
|
|
||||||
|
async def read_event(self) -> Event | None:
|
||||||
|
"""Receive."""
|
||||||
|
event = await super().read_event()
|
||||||
|
|
||||||
|
# Keep sending audio chunks instead of None
|
||||||
|
return event or self._mic_audio_chunk
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||||
|
"""Test running a pipeline with a satellite."""
|
||||||
|
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||||
|
|
||||||
|
events = [
|
||||||
|
RunPipeline(
|
||||||
|
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||||
|
).event(),
|
||||||
|
]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
|
SatelliteAsyncTcpClient(events),
|
||||||
|
) as mock_client, patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
|
) as mock_run_pipeline, patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||||
|
return_value=("wav", get_test_wav()),
|
||||||
|
):
|
||||||
|
entry = await setup_config_entry(hass)
|
||||||
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
|
entry.entry_id
|
||||||
|
].satellite.device
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.connect_event.wait()
|
||||||
|
await mock_client.run_satellite_event.wait()
|
||||||
|
|
||||||
|
mock_run_pipeline.assert_called()
|
||||||
|
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||||
|
|
||||||
|
# Start detecting wake word
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.WAKE_WORD_START
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.detect_event.wait()
|
||||||
|
|
||||||
|
assert not device.is_active
|
||||||
|
assert device.is_enabled
|
||||||
|
|
||||||
|
# Wake word is detected
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.WAKE_WORD_END,
|
||||||
|
{"wake_word_output": {"wake_word_id": "test_wake_word"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.detection_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.detection is not None
|
||||||
|
assert mock_client.detection.name == "test_wake_word"
|
||||||
|
|
||||||
|
# "Assist in progress" sensor should be active now
|
||||||
|
assert device.is_active
|
||||||
|
|
||||||
|
# Speech-to-text started
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_START,
|
||||||
|
{"metadata": {"language": "en"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.transcribe_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.transcribe is not None
|
||||||
|
assert mock_client.transcribe.language == "en"
|
||||||
|
|
||||||
|
# User started speaking
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.voice_started_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.voice_started is not None
|
||||||
|
assert mock_client.voice_started.timestamp == 1234
|
||||||
|
|
||||||
|
# User stopped speaking
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.voice_stopped_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.voice_stopped is not None
|
||||||
|
assert mock_client.voice_stopped.timestamp == 5678
|
||||||
|
|
||||||
|
# Speech-to-text transcription
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.STT_END,
|
||||||
|
{"stt_output": {"text": "test transcript"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.transcript_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.transcript is not None
|
||||||
|
assert mock_client.transcript.text == "test transcript"
|
||||||
|
|
||||||
|
# Text-to-speech text
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_START,
|
||||||
|
{
|
||||||
|
"tts_input": "test text to speak",
|
||||||
|
"voice": "test voice",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.synthesize_event.wait()
|
||||||
|
|
||||||
|
assert mock_client.synthesize is not None
|
||||||
|
assert mock_client.synthesize.text == "test text to speak"
|
||||||
|
assert mock_client.synthesize.voice is not None
|
||||||
|
assert mock_client.synthesize.voice.name == "test voice"
|
||||||
|
|
||||||
|
# Text-to-speech media
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(
|
||||||
|
assist_pipeline.PipelineEventType.TTS_END,
|
||||||
|
{"tts_output": {"media_id": "test media id"}},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await mock_client.tts_audio_start_event.wait()
|
||||||
|
await mock_client.tts_audio_chunk_event.wait()
|
||||||
|
await mock_client.tts_audio_stop_event.wait()
|
||||||
|
|
||||||
|
# Verify audio chunk from test WAV
|
||||||
|
assert mock_client.tts_audio_chunk is not None
|
||||||
|
assert mock_client.tts_audio_chunk.rate == 22050
|
||||||
|
assert mock_client.tts_audio_chunk.width == 2
|
||||||
|
assert mock_client.tts_audio_chunk.channels == 1
|
||||||
|
assert mock_client.tts_audio_chunk.audio == b"123"
|
||||||
|
|
||||||
|
# Pipeline finished
|
||||||
|
event_callback(
|
||||||
|
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||||
|
)
|
||||||
|
assert not device.is_active
|
||||||
|
|
||||||
|
# Stop the satellite
|
||||||
|
await hass.config_entries.async_unload(entry.entry_id)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_disabled(hass: HomeAssistant) -> None:
|
||||||
|
"""Test callback for a satellite that has been disabled."""
|
||||||
|
on_disabled_event = asyncio.Event()
|
||||||
|
|
||||||
|
original_make_satellite = wyoming._make_satellite
|
||||||
|
|
||||||
|
def make_disabled_satellite(
|
||||||
|
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||||
|
):
|
||||||
|
satellite = original_make_satellite(hass, config_entry, service)
|
||||||
|
satellite.device.is_enabled = False
|
||||||
|
|
||||||
|
return satellite
|
||||||
|
|
||||||
|
async def on_disabled(self):
|
||||||
|
on_disabled_event.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming._make_satellite", make_disabled_satellite
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_disabled",
|
||||||
|
on_disabled,
|
||||||
|
):
|
||||||
|
await setup_config_entry(hass)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await on_disabled_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_restart(hass: HomeAssistant) -> None:
|
||||||
|
"""Test pipeline loop restart after unexpected error."""
|
||||||
|
on_restart_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def on_restart(self):
|
||||||
|
self.stop()
|
||||||
|
on_restart_event.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
|
||||||
|
side_effect=RuntimeError(),
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
|
on_restart,
|
||||||
|
):
|
||||||
|
await setup_config_entry(hass)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await on_restart_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
||||||
|
"""Test satellite reconnect call after connection refused."""
|
||||||
|
on_reconnect_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def on_reconnect(self):
|
||||||
|
self.stop()
|
||||||
|
on_reconnect_event.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||||
|
side_effect=ConnectionRefusedError(),
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||||
|
on_reconnect,
|
||||||
|
):
|
||||||
|
await setup_config_entry(hass)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await on_reconnect_event.wait()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
|
||||||
|
"""Test satellite disconnecting before pipeline run."""
|
||||||
|
on_restart_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def on_restart(self):
|
||||||
|
self.stop()
|
||||||
|
on_restart_event.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient([]), # no RunPipeline event
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
|
) as mock_run_pipeline, patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
|
on_restart,
|
||||||
|
):
|
||||||
|
await setup_config_entry(hass)
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await on_restart_event.wait()
|
||||||
|
|
||||||
|
# Pipeline should never have run
|
||||||
|
mock_run_pipeline.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
|
||||||
|
"""Test satellite disconnecting during pipeline run."""
|
||||||
|
events = [
|
||||||
|
RunPipeline(
|
||||||
|
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||||
|
).event(),
|
||||||
|
] # no audio chunks after RunPipeline
|
||||||
|
|
||||||
|
on_restart_event = asyncio.Event()
|
||||||
|
on_stopped_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def on_restart(self):
|
||||||
|
# Pretend sensor got stuck on
|
||||||
|
self.device.is_active = True
|
||||||
|
self.stop()
|
||||||
|
on_restart_event.set()
|
||||||
|
|
||||||
|
async def on_stopped(self):
|
||||||
|
on_stopped_event.set()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||||
|
return_value=SATELLITE_INFO,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||||
|
MockAsyncTcpClient(events),
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||||
|
) as mock_run_pipeline, patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||||
|
on_restart,
|
||||||
|
), patch(
|
||||||
|
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||||
|
on_stopped,
|
||||||
|
):
|
||||||
|
entry = await setup_config_entry(hass)
|
||||||
|
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||||
|
entry.entry_id
|
||||||
|
].satellite.device
|
||||||
|
|
||||||
|
async with asyncio.timeout(1):
|
||||||
|
await on_restart_event.wait()
|
||||||
|
await on_stopped_event.wait()
|
||||||
|
|
||||||
|
# Pipeline should have run once
|
||||||
|
mock_run_pipeline.assert_called_once()
|
||||||
|
|
||||||
|
# Sensor should have been turned off
|
||||||
|
assert not device.is_active
|
|
@ -0,0 +1,83 @@
|
||||||
|
"""Test Wyoming select."""
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from homeassistant.components import assist_pipeline
|
||||||
|
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||||
|
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_select(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
satellite_config_entry: ConfigEntry,
|
||||||
|
satellite_device: SatelliteDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test pipeline select.
|
||||||
|
|
||||||
|
Functionality is tested in assist_pipeline/test_select.py.
|
||||||
|
This test is only to ensure it is set up.
|
||||||
|
"""
|
||||||
|
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||||
|
pipeline_data: PipelineData = hass.data[assist_pipeline.DOMAIN]
|
||||||
|
|
||||||
|
# Create second pipeline
|
||||||
|
await pipeline_data.pipeline_store.async_create_item(
|
||||||
|
{
|
||||||
|
"name": "Test 1",
|
||||||
|
"language": "en-US",
|
||||||
|
"conversation_engine": None,
|
||||||
|
"conversation_language": "en-US",
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Preferred pipeline is the default
|
||||||
|
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||||
|
assert pipeline_entity_id
|
||||||
|
|
||||||
|
state = hass.states.get(pipeline_entity_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == OPTION_PREFERRED
|
||||||
|
|
||||||
|
# Change to second pipeline
|
||||||
|
with patch.object(satellite_device, "set_pipeline_name") as mock_pipeline_changed:
|
||||||
|
await hass.services.async_call(
|
||||||
|
"select",
|
||||||
|
"select_option",
|
||||||
|
{"entity_id": pipeline_entity_id, "option": "Test 1"},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = hass.states.get(pipeline_entity_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == "Test 1"
|
||||||
|
|
||||||
|
# async_pipeline_changed should have been called
|
||||||
|
mock_pipeline_changed.assert_called_once_with("Test 1")
|
||||||
|
|
||||||
|
# Change back and check update listener
|
||||||
|
pipeline_listener = Mock()
|
||||||
|
satellite_device.set_pipeline_listener(pipeline_listener)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"select",
|
||||||
|
"select_option",
|
||||||
|
{"entity_id": pipeline_entity_id, "option": OPTION_PREFERRED},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = hass.states.get(pipeline_entity_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == OPTION_PREFERRED
|
||||||
|
|
||||||
|
# listener should have been called
|
||||||
|
pipeline_listener.assert_called_once()
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Test Wyoming switch devices."""
|
||||||
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
|
from homeassistant.const import STATE_OFF, STATE_ON
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
|
||||||
|
async def test_satellite_enabled(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
satellite_config_entry: ConfigEntry,
|
||||||
|
satellite_device: SatelliteDevice,
|
||||||
|
) -> None:
|
||||||
|
"""Test satellite enabled."""
|
||||||
|
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||||
|
assert satellite_enabled_id
|
||||||
|
|
||||||
|
state = hass.states.get(satellite_enabled_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_ON
|
||||||
|
assert satellite_device.is_enabled
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"switch",
|
||||||
|
"turn_off",
|
||||||
|
{"entity_id": satellite_enabled_id},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = hass.states.get(satellite_enabled_id)
|
||||||
|
assert state is not None
|
||||||
|
assert state.state == STATE_OFF
|
||||||
|
assert not satellite_device.is_enabled
|
Loading…
Reference in New Issue