Add Wyoming satellite (#104759)
* First draft of Wyoming satellite * Set up homeassistant in tests * Move satellite * Add devices with binary sensor and select * Add more events * Add satellite enabled switch * Fix mistake * Only set up necessary platforms for satellites * Lots of fixes * Add tests * Use config entry id as satellite id * Initial satellite test * Add satellite pipeline test * More tests * More satellite tests * Only support single device per config entry * Address comments * Make a copy of platformspull/105135/head
parent
db6b804298
commit
5a49e1dd5c
|
@ -4,17 +4,26 @@ from __future__ import annotations
|
|||
import logging
|
||||
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import ConfigEntryNotReady
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
|
||||
from .const import ATTR_SPEAKER, DOMAIN
|
||||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
from .models import DomainDataItem
|
||||
from .satellite import WyomingSatellite
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH]
|
||||
|
||||
__all__ = [
|
||||
"ATTR_SPEAKER",
|
||||
"DOMAIN",
|
||||
"async_setup_entry",
|
||||
"async_unload_entry",
|
||||
]
|
||||
|
||||
|
||||
|
@ -25,24 +34,72 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||
if service is None:
|
||||
raise ConfigEntryNotReady("Unable to connect")
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service
|
||||
item = DomainDataItem(service=service)
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item
|
||||
|
||||
await hass.config_entries.async_forward_entry_setups(
|
||||
entry,
|
||||
service.platforms,
|
||||
)
|
||||
await hass.config_entries.async_forward_entry_setups(entry, service.platforms)
|
||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||
|
||||
if (satellite_info := service.info.satellite) is not None:
|
||||
# Create satellite device, etc.
|
||||
item.satellite = _make_satellite(hass, entry, service)
|
||||
|
||||
# Set up satellite sensors, switches, etc.
|
||||
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
|
||||
|
||||
# Start satellite communication
|
||||
entry.async_create_background_task(
|
||||
hass,
|
||||
item.satellite.run(),
|
||||
f"Satellite {satellite_info.name}",
|
||||
)
|
||||
|
||||
entry.async_on_unload(item.satellite.stop)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _make_satellite(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
) -> WyomingSatellite:
|
||||
"""Create Wyoming satellite/device from config entry and Wyoming service."""
|
||||
satellite_info = service.info.satellite
|
||||
assert satellite_info is not None
|
||||
|
||||
dev_reg = dr.async_get(hass)
|
||||
|
||||
# Use config entry id since only one satellite per entry is supported
|
||||
satellite_id = config_entry.entry_id
|
||||
|
||||
device = dev_reg.async_get_or_create(
|
||||
config_entry_id=config_entry.entry_id,
|
||||
identifiers={(DOMAIN, satellite_id)},
|
||||
name=satellite_info.name,
|
||||
suggested_area=satellite_info.area,
|
||||
)
|
||||
|
||||
satellite_device = SatelliteDevice(
|
||||
satellite_id=satellite_id,
|
||||
device_id=device.id,
|
||||
)
|
||||
|
||||
return WyomingSatellite(hass, service, satellite_device)
|
||||
|
||||
|
||||
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Handle options update."""
|
||||
await hass.config_entries.async_reload(entry.entry_id)
|
||||
|
||||
|
||||
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Unload Wyoming."""
|
||||
service: WyomingService = hass.data[DOMAIN][entry.entry_id]
|
||||
item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(
|
||||
entry,
|
||||
service.platforms,
|
||||
)
|
||||
platforms = list(item.service.platforms)
|
||||
if item.satellite is not None:
|
||||
platforms += SATELLITE_PLATFORMS
|
||||
|
||||
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
|
||||
if unload_ok:
|
||||
del hass.data[DOMAIN][entry.entry_id]
|
||||
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
"""Binary sensor for Wyoming."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.components.binary_sensor import (
|
||||
BinarySensorEntity,
|
||||
BinarySensorEntityDescription,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import WyomingSatelliteEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up binary sensor entities."""
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
|
||||
|
||||
|
||||
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
|
||||
"""Entity to represent Assist is in progress for satellite."""
|
||||
|
||||
entity_description = BinarySensorEntityDescription(
|
||||
key="assist_in_progress",
|
||||
translation_key="assist_in_progress",
|
||||
)
|
||||
_attr_is_on = False
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
self._device.set_is_active_listener(self._is_active_changed)
|
||||
|
||||
@callback
|
||||
def _is_active_changed(self) -> None:
|
||||
"""Call when active state changed."""
|
||||
self._attr_is_on = self._device.is_active
|
||||
self.async_write_ha_state()
|
|
@ -1,19 +1,22 @@
|
|||
"""Config flow for Wyoming integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.hassio import HassioServiceInfo
|
||||
from homeassistant.const import CONF_HOST, CONF_PORT
|
||||
from homeassistant.components import hassio, zeroconf
|
||||
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
|
||||
_LOGGER = logging.getLogger()
|
||||
|
||||
STEP_USER_DATA_SCHEMA = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_HOST): str,
|
||||
|
@ -27,7 +30,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
|
||||
VERSION = 1
|
||||
|
||||
_hassio_discovery: HassioServiceInfo
|
||||
_hassio_discovery: hassio.HassioServiceInfo
|
||||
_service: WyomingService | None = None
|
||||
_name: str | None = None
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
|
@ -50,27 +55,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
errors={"base": "cannot_connect"},
|
||||
)
|
||||
|
||||
# ASR = automated speech recognition (speech-to-text)
|
||||
asr_installed = [asr for asr in service.info.asr if asr.installed]
|
||||
if name := service.get_name():
|
||||
return self.async_create_entry(title=name, data=user_input)
|
||||
|
||||
# TTS = text-to-speech
|
||||
tts_installed = [tts for tts in service.info.tts if tts.installed]
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
# wake-word-detection
|
||||
wake_installed = [wake for wake in service.info.wake if wake.installed]
|
||||
|
||||
if asr_installed:
|
||||
name = asr_installed[0].name
|
||||
elif tts_installed:
|
||||
name = tts_installed[0].name
|
||||
elif wake_installed:
|
||||
name = wake_installed[0].name
|
||||
else:
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
return self.async_create_entry(title=name, data=user_input)
|
||||
|
||||
async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult:
|
||||
async def async_step_hassio(
|
||||
self, discovery_info: hassio.HassioServiceInfo
|
||||
) -> FlowResult:
|
||||
"""Handle Supervisor add-on discovery."""
|
||||
await self.async_set_unique_id(discovery_info.uuid)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
@ -93,11 +85,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
if user_input is not None:
|
||||
uri = urlparse(self._hassio_discovery.config["uri"])
|
||||
if service := await WyomingService.create(uri.hostname, uri.port):
|
||||
if (
|
||||
not any(asr for asr in service.info.asr if asr.installed)
|
||||
and not any(tts for tts in service.info.tts if tts.installed)
|
||||
and not any(wake for wake in service.info.wake if wake.installed)
|
||||
):
|
||||
if not service.has_services():
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
return self.async_create_entry(
|
||||
|
@ -112,3 +100,52 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
|||
description_placeholders={"addon": self._hassio_discovery.name},
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_zeroconf(
|
||||
self, discovery_info: zeroconf.ZeroconfServiceInfo
|
||||
) -> FlowResult:
|
||||
"""Handle zeroconf discovery."""
|
||||
_LOGGER.debug("Discovery info: %s", discovery_info)
|
||||
if discovery_info.port is None:
|
||||
return self.async_abort(reason="no_port")
|
||||
|
||||
service = await WyomingService.create(discovery_info.host, discovery_info.port)
|
||||
if (service is None) or (not (name := service.get_name())):
|
||||
# No supported services
|
||||
return self.async_abort(reason="no_services")
|
||||
|
||||
self._name = name
|
||||
|
||||
# Use zeroconf name + service name as unique id.
|
||||
# The satellite will use its own MAC as the zeroconf name by default.
|
||||
unique_id = f"{discovery_info.name}_{self._name}"
|
||||
await self.async_set_unique_id(unique_id)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
self.context[CONF_NAME] = self._name
|
||||
self.context["title_placeholders"] = {"name": self._name}
|
||||
|
||||
self._service = service
|
||||
return await self.async_step_zeroconf_confirm()
|
||||
|
||||
async def async_step_zeroconf_confirm(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle a flow initiated by zeroconf."""
|
||||
assert self._service is not None
|
||||
assert self._name is not None
|
||||
|
||||
if user_input is None:
|
||||
return self.async_show_form(
|
||||
step_id="zeroconf_confirm",
|
||||
description_placeholders={"name": self._name},
|
||||
errors={},
|
||||
)
|
||||
|
||||
return self.async_create_entry(
|
||||
title=self._name,
|
||||
data={
|
||||
CONF_HOST: self._service.host,
|
||||
CONF_PORT: self._service.port,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.info import Describe, Info
|
||||
from wyoming.info import Describe, Info, Satellite
|
||||
|
||||
from homeassistant.const import Platform
|
||||
|
||||
|
@ -32,6 +32,43 @@ class WyomingService:
|
|||
platforms.append(Platform.WAKE_WORD)
|
||||
self.platforms = platforms
|
||||
|
||||
def has_services(self) -> bool:
|
||||
"""Return True if services are installed that Home Assistant can use."""
|
||||
return (
|
||||
any(asr for asr in self.info.asr if asr.installed)
|
||||
or any(tts for tts in self.info.tts if tts.installed)
|
||||
or any(wake for wake in self.info.wake if wake.installed)
|
||||
or ((self.info.satellite is not None) and self.info.satellite.installed)
|
||||
)
|
||||
|
||||
def get_name(self) -> str | None:
|
||||
"""Return name of first installed usable service."""
|
||||
# ASR = automated speech recognition (speech-to-text)
|
||||
asr_installed = [asr for asr in self.info.asr if asr.installed]
|
||||
if asr_installed:
|
||||
return asr_installed[0].name
|
||||
|
||||
# TTS = text-to-speech
|
||||
tts_installed = [tts for tts in self.info.tts if tts.installed]
|
||||
if tts_installed:
|
||||
return tts_installed[0].name
|
||||
|
||||
# wake-word-detection
|
||||
wake_installed = [wake for wake in self.info.wake if wake.installed]
|
||||
if wake_installed:
|
||||
return wake_installed[0].name
|
||||
|
||||
# satellite
|
||||
satellite_installed: Satellite | None = None
|
||||
|
||||
if (self.info.satellite is not None) and self.info.satellite.installed:
|
||||
satellite_installed = self.info.satellite
|
||||
|
||||
if satellite_installed:
|
||||
return satellite_installed.name
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def create(cls, host: str, port: int) -> WyomingService | None:
|
||||
"""Create a Wyoming service."""
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
"""Class to manage satellite devices."""
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
|
||||
from .const import DOMAIN
|
||||
|
||||
|
||||
@dataclass
|
||||
class SatelliteDevice:
|
||||
"""Class to store device."""
|
||||
|
||||
satellite_id: str
|
||||
device_id: str
|
||||
is_active: bool = False
|
||||
is_enabled: bool = True
|
||||
pipeline_name: str | None = None
|
||||
|
||||
_is_active_listener: Callable[[], None] | None = None
|
||||
_is_enabled_listener: Callable[[], None] | None = None
|
||||
_pipeline_listener: Callable[[], None] | None = None
|
||||
|
||||
@callback
|
||||
def set_is_active(self, active: bool) -> None:
|
||||
"""Set active state."""
|
||||
if active != self.is_active:
|
||||
self.is_active = active
|
||||
if self._is_active_listener is not None:
|
||||
self._is_active_listener()
|
||||
|
||||
@callback
|
||||
def set_is_enabled(self, enabled: bool) -> None:
|
||||
"""Set enabled state."""
|
||||
if enabled != self.is_enabled:
|
||||
self.is_enabled = enabled
|
||||
if self._is_enabled_listener is not None:
|
||||
self._is_enabled_listener()
|
||||
|
||||
@callback
|
||||
def set_pipeline_name(self, pipeline_name: str) -> None:
|
||||
"""Inform listeners that pipeline selection has changed."""
|
||||
if pipeline_name != self.pipeline_name:
|
||||
self.pipeline_name = pipeline_name
|
||||
if self._pipeline_listener is not None:
|
||||
self._pipeline_listener()
|
||||
|
||||
@callback
|
||||
def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None:
|
||||
"""Listen for updates to is_active."""
|
||||
self._is_active_listener = is_active_listener
|
||||
|
||||
@callback
|
||||
def set_is_enabled_listener(self, is_enabled_listener: Callable[[], None]) -> None:
|
||||
"""Listen for updates to is_enabled."""
|
||||
self._is_enabled_listener = is_enabled_listener
|
||||
|
||||
@callback
|
||||
def set_pipeline_listener(self, pipeline_listener: Callable[[], None]) -> None:
|
||||
"""Listen for updates to pipeline."""
|
||||
self._pipeline_listener = pipeline_listener
|
||||
|
||||
def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for assist in progress binary sensor."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"binary_sensor", DOMAIN, f"{self.satellite_id}-assist_in_progress"
|
||||
)
|
||||
|
||||
def get_satellite_enabled_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for satellite enabled switch."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"switch", DOMAIN, f"{self.satellite_id}-satellite_enabled"
|
||||
)
|
||||
|
||||
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
|
||||
"""Return entity id for pipeline select."""
|
||||
ent_reg = er.async_get(hass)
|
||||
return ent_reg.async_get_entity_id(
|
||||
"select", DOMAIN, f"{self.satellite_id}-pipeline"
|
||||
)
|
|
@ -0,0 +1,24 @@
|
|||
"""Wyoming entities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.helpers import entity
|
||||
from homeassistant.helpers.device_registry import DeviceInfo
|
||||
|
||||
from .const import DOMAIN
|
||||
from .satellite import SatelliteDevice
|
||||
|
||||
|
||||
class WyomingSatelliteEntity(entity.Entity):
|
||||
"""Wyoming satellite entity."""
|
||||
|
||||
_attr_has_entity_name = True
|
||||
_attr_should_poll = False
|
||||
|
||||
def __init__(self, device: SatelliteDevice) -> None:
|
||||
"""Initialize entity."""
|
||||
self._device = device
|
||||
self._attr_unique_id = f"{device.satellite_id}-{self.entity_description.key}"
|
||||
self._attr_device_info = DeviceInfo(
|
||||
identifiers={(DOMAIN, device.satellite_id)},
|
||||
)
|
|
@ -3,7 +3,9 @@
|
|||
"name": "Wyoming Protocol",
|
||||
"codeowners": ["@balloob", "@synesthesiam"],
|
||||
"config_flow": true,
|
||||
"dependencies": ["assist_pipeline"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/wyoming",
|
||||
"iot_class": "local_push",
|
||||
"requirements": ["wyoming==1.2.0"]
|
||||
"requirements": ["wyoming==1.3.0"],
|
||||
"zeroconf": ["_wyoming._tcp.local."]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,13 @@
|
|||
"""Models for wyoming."""
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .data import WyomingService
|
||||
from .satellite import WyomingSatellite
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainDataItem:
|
||||
"""Domain data item."""
|
||||
|
||||
service: WyomingService
|
||||
satellite: WyomingSatellite | None = None
|
|
@ -0,0 +1,380 @@
|
|||
"""Support for Wyoming satellite services."""
|
||||
import asyncio
|
||||
from collections.abc import AsyncGenerator
|
||||
import io
|
||||
import logging
|
||||
from typing import Final
|
||||
import wave
|
||||
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
|
||||
from wyoming.client import AsyncTcpClient
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.tts import Synthesize, SynthesizeVoice
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, stt, tts
|
||||
from homeassistant.components.assist_pipeline import select as pipeline_select
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
|
||||
from .const import DOMAIN
|
||||
from .data import WyomingService
|
||||
from .devices import SatelliteDevice
|
||||
|
||||
_LOGGER = logging.getLogger()
|
||||
|
||||
_SAMPLES_PER_CHUNK: Final = 1024
|
||||
_RECONNECT_SECONDS: Final = 10
|
||||
_RESTART_SECONDS: Final = 3
|
||||
|
||||
# Wyoming stage -> Assist stage
|
||||
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
|
||||
PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
|
||||
PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
|
||||
PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
|
||||
}
|
||||
|
||||
|
||||
class WyomingSatellite:
|
||||
"""Remove voice satellite running the Wyoming protocol."""
|
||||
|
||||
def __init__(
|
||||
self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice
|
||||
) -> None:
|
||||
"""Initialize satellite."""
|
||||
self.hass = hass
|
||||
self.service = service
|
||||
self.device = device
|
||||
self.is_enabled = True
|
||||
self.is_running = True
|
||||
|
||||
self._client: AsyncTcpClient | None = None
|
||||
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
|
||||
self._is_pipeline_running = False
|
||||
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
|
||||
self._pipeline_id: str | None = None
|
||||
self._enabled_changed_event = asyncio.Event()
|
||||
|
||||
self.device.set_is_enabled_listener(self._enabled_changed)
|
||||
self.device.set_pipeline_listener(self._pipeline_changed)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Run and maintain a connection to satellite."""
|
||||
_LOGGER.debug("Running satellite task")
|
||||
|
||||
try:
|
||||
while self.is_running:
|
||||
try:
|
||||
# Check if satellite has been disabled
|
||||
if not self.device.is_enabled:
|
||||
await self.on_disabled()
|
||||
if not self.is_running:
|
||||
# Satellite was stopped while waiting to be enabled
|
||||
break
|
||||
|
||||
# Connect and run pipeline loop
|
||||
await self._run_once()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
await self.on_restart()
|
||||
finally:
|
||||
# Ensure sensor is off
|
||||
self.device.set_is_active(False)
|
||||
|
||||
await self.on_stopped()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal satellite task to stop running."""
|
||||
self.is_running = False
|
||||
|
||||
# Unblock waiting for enabled
|
||||
self._enabled_changed_event.set()
|
||||
|
||||
async def on_restart(self) -> None:
|
||||
"""Block until pipeline loop will be restarted."""
|
||||
_LOGGER.warning(
|
||||
"Unexpected error running satellite. Restarting in %s second(s)",
|
||||
_RECONNECT_SECONDS,
|
||||
)
|
||||
await asyncio.sleep(_RESTART_SECONDS)
|
||||
|
||||
async def on_reconnect(self) -> None:
|
||||
"""Block until a reconnection attempt should be made."""
|
||||
_LOGGER.debug(
|
||||
"Failed to connect to satellite. Reconnecting in %s second(s)",
|
||||
_RECONNECT_SECONDS,
|
||||
)
|
||||
await asyncio.sleep(_RECONNECT_SECONDS)
|
||||
|
||||
async def on_disabled(self) -> None:
|
||||
"""Block until device may be enabled again."""
|
||||
await self._enabled_changed_event.wait()
|
||||
|
||||
async def on_stopped(self) -> None:
|
||||
"""Run when run() has fully stopped."""
|
||||
_LOGGER.debug("Satellite task stopped")
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _enabled_changed(self) -> None:
|
||||
"""Run when device enabled status changes."""
|
||||
|
||||
if not self.device.is_enabled:
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
self._enabled_changed_event.set()
|
||||
|
||||
def _pipeline_changed(self) -> None:
|
||||
"""Run when device pipeline changes."""
|
||||
|
||||
# Cancel any running pipeline
|
||||
self._audio_queue.put_nowait(None)
|
||||
|
||||
async def _run_once(self) -> None:
|
||||
"""Run pipelines until an error occurs."""
|
||||
self.device.set_is_active(False)
|
||||
|
||||
while self.is_running and self.is_enabled:
|
||||
try:
|
||||
await self._connect()
|
||||
break
|
||||
except ConnectionError:
|
||||
await self.on_reconnect()
|
||||
|
||||
assert self._client is not None
|
||||
_LOGGER.debug("Connected to satellite")
|
||||
|
||||
if (not self.is_running) or (not self.is_enabled):
|
||||
# Run was cancelled or satellite was disabled during connection
|
||||
return
|
||||
|
||||
# Tell satellite that we're ready
|
||||
await self._client.write_event(RunSatellite().event())
|
||||
|
||||
# Wait until we get RunPipeline event
|
||||
run_pipeline: RunPipeline | None = None
|
||||
while self.is_running and self.is_enabled:
|
||||
run_event = await self._client.read_event()
|
||||
if run_event is None:
|
||||
raise ConnectionResetError("Satellite disconnected")
|
||||
|
||||
if RunPipeline.is_type(run_event.type):
|
||||
run_pipeline = RunPipeline.from_event(run_event)
|
||||
break
|
||||
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", run_event)
|
||||
|
||||
assert run_pipeline is not None
|
||||
_LOGGER.debug("Received run information: %s", run_pipeline)
|
||||
|
||||
if (not self.is_running) or (not self.is_enabled):
|
||||
# Run was cancelled or satellite was disabled while waiting for
|
||||
# RunPipeline event.
|
||||
return
|
||||
|
||||
start_stage = _STAGES.get(run_pipeline.start_stage)
|
||||
end_stage = _STAGES.get(run_pipeline.end_stage)
|
||||
|
||||
if start_stage is None:
|
||||
raise ValueError(f"Invalid start stage: {start_stage}")
|
||||
|
||||
if end_stage is None:
|
||||
raise ValueError(f"Invalid end stage: {end_stage}")
|
||||
|
||||
# Each loop is a pipeline run
|
||||
while self.is_running and self.is_enabled:
|
||||
# Use select to get pipeline each time in case it's changed
|
||||
pipeline_id = pipeline_select.get_chosen_pipeline(
|
||||
self.hass,
|
||||
DOMAIN,
|
||||
self.device.satellite_id,
|
||||
)
|
||||
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
|
||||
assert pipeline is not None
|
||||
|
||||
# We will push audio in through a queue
|
||||
self._audio_queue = asyncio.Queue()
|
||||
stt_stream = self._stt_stream()
|
||||
|
||||
# Start pipeline running
|
||||
_LOGGER.debug(
|
||||
"Starting pipeline %s from %s to %s",
|
||||
pipeline.name,
|
||||
start_stage,
|
||||
end_stage,
|
||||
)
|
||||
self._is_pipeline_running = True
|
||||
_pipeline_task = asyncio.create_task(
|
||||
assist_pipeline.async_pipeline_from_audio_stream(
|
||||
self.hass,
|
||||
context=Context(),
|
||||
event_callback=self._event_callback,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language=pipeline.language,
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=stt_stream,
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
tts_audio_output="wav",
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Run until pipeline is complete or cancelled with an empty audio chunk
|
||||
while self._is_pipeline_running:
|
||||
client_event = await self._client.read_event()
|
||||
if client_event is None:
|
||||
raise ConnectionResetError("Satellite disconnected")
|
||||
|
||||
if AudioChunk.is_type(client_event.type):
|
||||
# Microphone audio
|
||||
chunk = AudioChunk.from_event(client_event)
|
||||
chunk = self._chunk_converter.convert(chunk)
|
||||
self._audio_queue.put_nowait(chunk.audio)
|
||||
else:
|
||||
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
|
||||
|
||||
_LOGGER.debug("Pipeline finished")
|
||||
|
||||
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
|
||||
"""Translate pipeline events into Wyoming events."""
|
||||
assert self._client is not None
|
||||
|
||||
if event.type == assist_pipeline.PipelineEventType.RUN_END:
|
||||
# Pipeline run is complete
|
||||
self._is_pipeline_running = False
|
||||
self.device.set_is_active(False)
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||
# Wake word detection
|
||||
self.device.set_is_active(True)
|
||||
|
||||
# Inform client of wake word detection
|
||||
if event.data and (wake_word_output := event.data.get("wake_word_output")):
|
||||
detection = Detection(
|
||||
name=wake_word_output["wake_word_id"],
|
||||
timestamp=wake_word_output.get("timestamp"),
|
||||
)
|
||||
self.hass.add_job(self._client.write_event(detection.event()))
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||
# Speech-to-text
|
||||
self.device.set_is_active(True)
|
||||
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||
# User started speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||
# User stopped speaking
|
||||
if event.data:
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||
# Speech-to-text transcript
|
||||
if event.data:
|
||||
# Inform client of transript
|
||||
stt_text = event.data["stt_output"]["text"]
|
||||
self.hass.add_job(
|
||||
self._client.write_event(Transcript(text=stt_text).event())
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||
# Text-to-speech text
|
||||
if event.data:
|
||||
# Inform client of text
|
||||
self.hass.add_job(
|
||||
self._client.write_event(
|
||||
Synthesize(
|
||||
text=event.data["tts_input"],
|
||||
voice=SynthesizeVoice(
|
||||
name=event.data.get("voice"),
|
||||
language=event.data.get("language"),
|
||||
),
|
||||
).event()
|
||||
)
|
||||
)
|
||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# TTS stream
|
||||
if event.data and (tts_output := event.data["tts_output"]):
|
||||
media_id = tts_output["media_id"]
|
||||
self.hass.add_job(self._stream_tts(media_id))
|
||||
|
||||
async def _connect(self) -> None:
|
||||
"""Connect to satellite over TCP."""
|
||||
_LOGGER.debug(
|
||||
"Connecting to satellite at %s:%s", self.service.host, self.service.port
|
||||
)
|
||||
self._client = AsyncTcpClient(self.service.host, self.service.port)
|
||||
await self._client.connect()
|
||||
|
||||
async def _stream_tts(self, media_id: str) -> None:
|
||||
"""Stream TTS WAV audio to satellite in chunks."""
|
||||
assert self._client is not None
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
|
||||
if extension != "wav":
|
||||
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
sample_rate = wav_file.getframerate()
|
||||
sample_width = wav_file.getsampwidth()
|
||||
sample_channels = wav_file.getnchannels()
|
||||
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
|
||||
|
||||
timestamp = 0
|
||||
await self._client.write_event(
|
||||
AudioStart(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
timestamp=timestamp,
|
||||
).event()
|
||||
)
|
||||
|
||||
# Stream audio chunks
|
||||
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
|
||||
chunk = AudioChunk(
|
||||
rate=sample_rate,
|
||||
width=sample_width,
|
||||
channels=sample_channels,
|
||||
audio=audio_bytes,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
await self._client.write_event(chunk.event())
|
||||
timestamp += chunk.seconds
|
||||
|
||||
await self._client.write_event(AudioStop(timestamp=timestamp).event())
|
||||
_LOGGER.debug("TTS streaming complete")
|
||||
|
||||
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
|
||||
"""Yield audio chunks from a queue."""
|
||||
is_first_chunk = True
|
||||
while chunk := await self._audio_queue.get():
|
||||
if is_first_chunk:
|
||||
is_first_chunk = False
|
||||
_LOGGER.debug("Receiving audio from satellite")
|
||||
|
||||
yield chunk
|
|
@ -0,0 +1,47 @@
|
|||
"""Select entities for VoIP integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .devices import SatelliteDevice
|
||||
from .entity import WyomingSatelliteEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP switch entities."""
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
|
||||
async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)])
|
||||
|
||||
|
||||
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
|
||||
"""Pipeline selector for Wyoming satellites."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
|
||||
"""Initialize a pipeline selector."""
|
||||
self.device = device
|
||||
|
||||
WyomingSatelliteEntity.__init__(self, device)
|
||||
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
|
||||
|
||||
async def async_select_option(self, option: str) -> None:
|
||||
"""Select an option."""
|
||||
await super().async_select_option(option)
|
||||
self.device.set_pipeline_name(option)
|
|
@ -9,6 +9,10 @@
|
|||
},
|
||||
"hassio_confirm": {
|
||||
"description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?"
|
||||
},
|
||||
"zeroconf_confirm": {
|
||||
"description": "Do you want to configure Home Assistant to connect to the Wyoming service {name}?",
|
||||
"title": "Discovered Wyoming service"
|
||||
}
|
||||
},
|
||||
"error": {
|
||||
|
@ -16,7 +20,31 @@
|
|||
},
|
||||
"abort": {
|
||||
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
|
||||
"no_services": "No services found at endpoint"
|
||||
"no_services": "No services found at endpoint",
|
||||
"no_port": "No port for endpoint"
|
||||
}
|
||||
},
|
||||
"entity": {
|
||||
"binary_sensor": {
|
||||
"assist_in_progress": {
|
||||
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
|
||||
}
|
||||
},
|
||||
"select": {
|
||||
"pipeline": {
|
||||
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
|
||||
"state": {
|
||||
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
|
||||
}
|
||||
},
|
||||
"noise_suppression": {
|
||||
"name": "Noise suppression"
|
||||
}
|
||||
},
|
||||
"switch": {
|
||||
"satellite_enabled": {
|
||||
"name": "Satellite enabled"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -14,6 +14,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -24,10 +25,10 @@ async def async_setup_entry(
|
|||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming speech-to-text."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingSttProvider(config_entry, service),
|
||||
WyomingSttProvider(config_entry, item.service),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
"""Wyoming switch entities."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_ON, EntityCategory
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import restore_state
|
||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||
|
||||
from .const import DOMAIN
|
||||
from .entity import WyomingSatelliteEntity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .models import DomainDataItem
|
||||
|
||||
|
||||
async def async_setup_entry(
|
||||
hass: HomeAssistant,
|
||||
config_entry: ConfigEntry,
|
||||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up VoIP switch entities."""
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
|
||||
# Setup is only forwarded for satellites
|
||||
assert item.satellite is not None
|
||||
|
||||
async_add_entities([WyomingSatelliteEnabledSwitch(item.satellite.device)])
|
||||
|
||||
|
||||
class WyomingSatelliteEnabledSwitch(
|
||||
WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity
|
||||
):
|
||||
"""Entity to represent if satellite is enabled."""
|
||||
|
||||
entity_description = SwitchEntityDescription(
|
||||
key="satellite_enabled",
|
||||
translation_key="satellite_enabled",
|
||||
entity_category=EntityCategory.CONFIG,
|
||||
)
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
"""Call when entity about to be added to hass."""
|
||||
await super().async_added_to_hass()
|
||||
|
||||
state = await self.async_get_last_state()
|
||||
|
||||
# Default to on
|
||||
self._attr_is_on = (state is None) or (state.state == STATE_ON)
|
||||
|
||||
async def async_turn_on(self, **kwargs: Any) -> None:
|
||||
"""Turn on."""
|
||||
self._attr_is_on = True
|
||||
self.async_write_ha_state()
|
||||
self._device.set_is_enabled(True)
|
||||
|
||||
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||
"""Turn off."""
|
||||
self._attr_is_on = False
|
||||
self.async_write_ha_state()
|
||||
self._device.set_is_enabled(False)
|
|
@ -16,6 +16,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
from .const import ATTR_SPEAKER, DOMAIN
|
||||
from .data import WyomingService
|
||||
from .error import WyomingError
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -26,10 +27,10 @@ async def async_setup_entry(
|
|||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming speech-to-text."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingTtsProvider(config_entry, service),
|
||||
WyomingTtsProvider(config_entry, item.service),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|||
from .const import DOMAIN
|
||||
from .data import WyomingService, load_wyoming_info
|
||||
from .error import WyomingError
|
||||
from .models import DomainDataItem
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -25,10 +26,10 @@ async def async_setup_entry(
|
|||
async_add_entities: AddEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up Wyoming wake-word-detection."""
|
||||
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
|
||||
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
|
||||
async_add_entities(
|
||||
[
|
||||
WyomingWakeWordProvider(hass, config_entry, service),
|
||||
WyomingWakeWordProvider(hass, config_entry, item.service),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -715,6 +715,11 @@ ZEROCONF = {
|
|||
"domain": "wled",
|
||||
},
|
||||
],
|
||||
"_wyoming._tcp.local.": [
|
||||
{
|
||||
"domain": "wyoming",
|
||||
},
|
||||
],
|
||||
"_xbmc-jsonrpc-h._tcp.local.": [
|
||||
{
|
||||
"domain": "kodi",
|
||||
|
|
|
@ -2750,7 +2750,7 @@ wled==0.17.0
|
|||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.2.0
|
||||
wyoming==1.3.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -2054,7 +2054,7 @@ wled==0.17.0
|
|||
wolf-smartset==0.1.11
|
||||
|
||||
# homeassistant.components.wyoming
|
||||
wyoming==1.2.0
|
||||
wyoming==1.3.0
|
||||
|
||||
# homeassistant.components.xbox
|
||||
xbox-webapi==2.0.11
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
"""Tests for the Wyoming integration."""
|
||||
import asyncio
|
||||
|
||||
from wyoming.event import Event
|
||||
from wyoming.info import (
|
||||
AsrModel,
|
||||
AsrProgram,
|
||||
Attribution,
|
||||
Info,
|
||||
Satellite,
|
||||
TtsProgram,
|
||||
TtsVoice,
|
||||
TtsVoiceSpeaker,
|
||||
|
@ -72,24 +74,36 @@ WAKE_WORD_INFO = Info(
|
|||
)
|
||||
]
|
||||
)
|
||||
SATELLITE_INFO = Info(
|
||||
satellite=Satellite(
|
||||
name="Test Satellite",
|
||||
description="Test Satellite",
|
||||
installed=True,
|
||||
attribution=TEST_ATTR,
|
||||
area="Office",
|
||||
)
|
||||
)
|
||||
EMPTY_INFO = Info()
|
||||
|
||||
|
||||
class MockAsyncTcpClient:
|
||||
"""Mock AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses) -> None:
|
||||
def __init__(self, responses: list[Event]) -> None:
|
||||
"""Initialize."""
|
||||
self.host = None
|
||||
self.port = None
|
||||
self.written = []
|
||||
self.host: str | None = None
|
||||
self.port: int | None = None
|
||||
self.written: list[Event] = []
|
||||
self.responses = responses
|
||||
|
||||
async def write_event(self, event):
|
||||
async def connect(self) -> None:
|
||||
"""Connect."""
|
||||
|
||||
async def write_event(self, event: Event):
|
||||
"""Send."""
|
||||
self.written.append(event)
|
||||
|
||||
async def read_event(self):
|
||||
async def read_event(self) -> Event | None:
|
||||
"""Receive."""
|
||||
await asyncio.sleep(0) # force context switch
|
||||
|
||||
|
@ -105,7 +119,7 @@ class MockAsyncTcpClient:
|
|||
async def __aexit__(self, exc_type, exc, tb):
|
||||
"""Exit."""
|
||||
|
||||
def __call__(self, host, port):
|
||||
def __call__(self, host: str, port: int):
|
||||
"""Call."""
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
|
|
@ -5,14 +5,23 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
||||
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def init_components(hass: HomeAssistant):
|
||||
"""Set up required components."""
|
||||
assert await async_setup_component(hass, "homeassistant", {})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_setup_entry() -> Generator[AsyncMock, None, None]:
|
||||
"""Override async_setup_entry."""
|
||||
|
@ -110,3 +119,39 @@ def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
|
|||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry:
|
||||
"""Create a config entry."""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test Satellite",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
return entry
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntry):
|
||||
"""Initialize Wyoming satellite."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
|
||||
) as _run_mock:
|
||||
# _run_mock: satellite task does not actually run
|
||||
await hass.config_entries.async_setup(satellite_config_entry.entry_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def satellite_device(
|
||||
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
|
||||
) -> SatelliteDevice:
|
||||
"""Get a satellite device fixture."""
|
||||
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device
|
||||
|
|
|
@ -121,3 +121,45 @@
|
|||
'version': 1,
|
||||
})
|
||||
# ---
|
||||
# name: test_zeroconf_discovery
|
||||
FlowResultSnapshot({
|
||||
'context': dict({
|
||||
'name': 'Test Satellite',
|
||||
'source': 'zeroconf',
|
||||
'title_placeholders': dict({
|
||||
'name': 'Test Satellite',
|
||||
}),
|
||||
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
|
||||
}),
|
||||
'data': dict({
|
||||
'host': '127.0.0.1',
|
||||
'port': 12345,
|
||||
}),
|
||||
'description': None,
|
||||
'description_placeholders': None,
|
||||
'flow_id': <ANY>,
|
||||
'handler': 'wyoming',
|
||||
'options': dict({
|
||||
}),
|
||||
'result': ConfigEntrySnapshot({
|
||||
'data': dict({
|
||||
'host': '127.0.0.1',
|
||||
'port': 12345,
|
||||
}),
|
||||
'disabled_by': None,
|
||||
'domain': 'wyoming',
|
||||
'entry_id': <ANY>,
|
||||
'options': dict({
|
||||
}),
|
||||
'pref_disable_new_entities': False,
|
||||
'pref_disable_polling': False,
|
||||
'source': 'zeroconf',
|
||||
'title': 'Test Satellite',
|
||||
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
|
||||
'version': 1,
|
||||
}),
|
||||
'title': 'Test Satellite',
|
||||
'type': <FlowResultType.CREATE_ENTRY: 'create_entry'>,
|
||||
'version': 1,
|
||||
})
|
||||
# ---
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
"""Test Wyoming binary sensor devices."""
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_assist_in_progress(
|
||||
hass: HomeAssistant,
|
||||
satellite_config_entry: ConfigEntry,
|
||||
satellite_device: SatelliteDevice,
|
||||
) -> None:
|
||||
"""Test assist in progress."""
|
||||
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||
assert assist_in_progress_id
|
||||
|
||||
state = hass.states.get(assist_in_progress_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
assert not satellite_device.is_active
|
||||
|
||||
satellite_device.set_is_active(True)
|
||||
|
||||
state = hass.states.get(assist_in_progress_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
assert satellite_device.is_active
|
||||
|
||||
satellite_device.set_is_active(False)
|
||||
|
||||
state = hass.states.get(assist_in_progress_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
assert not satellite_device.is_active
|
|
@ -1,4 +1,5 @@
|
|||
"""Test the Wyoming config flow."""
|
||||
from ipaddress import IPv4Address
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
@ -8,10 +9,11 @@ from wyoming.info import Info
|
|||
from homeassistant import config_entries
|
||||
from homeassistant.components.hassio import HassioServiceInfo
|
||||
from homeassistant.components.wyoming.const import DOMAIN
|
||||
from homeassistant.components.zeroconf import ZeroconfServiceInfo
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.data_entry_flow import FlowResultType
|
||||
|
||||
from . import EMPTY_INFO, STT_INFO, TTS_INFO
|
||||
from . import EMPTY_INFO, SATELLITE_INFO, STT_INFO, TTS_INFO
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
@ -25,6 +27,16 @@ ADDON_DISCOVERY = HassioServiceInfo(
|
|||
uuid="1234",
|
||||
)
|
||||
|
||||
ZEROCONF_DISCOVERY = ZeroconfServiceInfo(
|
||||
ip_address=IPv4Address("127.0.0.1"),
|
||||
ip_addresses=[IPv4Address("127.0.0.1")],
|
||||
port=12345,
|
||||
hostname="localhost",
|
||||
type="_wyoming._tcp.local.",
|
||||
name="test_zeroconf_name._wyoming._tcp.local.",
|
||||
properties={},
|
||||
)
|
||||
|
||||
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
|
||||
|
||||
|
||||
|
@ -214,3 +226,70 @@ async def test_hassio_addon_no_supported_services(hass: HomeAssistant) -> None:
|
|||
|
||||
assert result2.get("type") == FlowResultType.ABORT
|
||||
assert result2.get("reason") == "no_services"
|
||||
|
||||
|
||||
async def test_zeroconf_discovery(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test config flow initiated by Supervisor."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
data=ZEROCONF_DISCOVERY,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
)
|
||||
|
||||
assert result.get("type") == FlowResultType.FORM
|
||||
assert result.get("step_id") == "zeroconf_confirm"
|
||||
assert result.get("description_placeholders") == {
|
||||
"name": SATELLITE_INFO.satellite.name
|
||||
}
|
||||
|
||||
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
|
||||
assert result2.get("type") == FlowResultType.CREATE_ENTRY
|
||||
assert result2 == snapshot
|
||||
|
||||
|
||||
async def test_zeroconf_discovery_no_port(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test discovery when the zeroconf service does not have a port."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch.object(ZEROCONF_DISCOVERY, "port", None):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
data=ZEROCONF_DISCOVERY,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
)
|
||||
|
||||
assert result.get("type") == FlowResultType.ABORT
|
||||
assert result.get("reason") == "no_port"
|
||||
|
||||
|
||||
async def test_zeroconf_discovery_no_services(
|
||||
hass: HomeAssistant,
|
||||
mock_setup_entry: AsyncMock,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test discovery when there are no supported services on the client."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=Info(),
|
||||
):
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
data=ZEROCONF_DISCOVERY,
|
||||
context={"source": config_entries.SOURCE_ZEROCONF},
|
||||
)
|
||||
|
||||
assert result.get("type") == FlowResultType.ABORT
|
||||
assert result.get("reason") == "no_services"
|
||||
|
|
|
@ -3,13 +3,15 @@ from __future__ import annotations
|
|||
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components.wyoming.data import load_wyoming_info
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from . import STT_INFO, MockAsyncTcpClient
|
||||
from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||
|
||||
|
||||
async def test_load_info(hass: HomeAssistant, snapshot) -> None:
|
||||
async def test_load_info(hass: HomeAssistant, snapshot: SnapshotAssertion) -> None:
|
||||
"""Test loading info."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
|
@ -38,3 +40,38 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
|
|||
)
|
||||
|
||||
assert info is None
|
||||
|
||||
|
||||
async def test_service_name(hass: HomeAssistant) -> None:
|
||||
"""Test loading service info."""
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
MockAsyncTcpClient([STT_INFO.event()]),
|
||||
):
|
||||
service = await WyomingService.create("localhost", 1234)
|
||||
assert service is not None
|
||||
assert service.get_name() == STT_INFO.asr[0].name
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
MockAsyncTcpClient([TTS_INFO.event()]),
|
||||
):
|
||||
service = await WyomingService.create("localhost", 1234)
|
||||
assert service is not None
|
||||
assert service.get_name() == TTS_INFO.tts[0].name
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
MockAsyncTcpClient([WAKE_WORD_INFO.event()]),
|
||||
):
|
||||
service = await WyomingService.create("localhost", 1234)
|
||||
assert service is not None
|
||||
assert service.get_name() == WAKE_WORD_INFO.wake[0].name
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.AsyncTcpClient",
|
||||
MockAsyncTcpClient([SATELLITE_INFO.event()]),
|
||||
):
|
||||
service = await WyomingService.create("localhost", 1234)
|
||||
assert service is not None
|
||||
assert service.get_name() == SATELLITE_INFO.satellite.name
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
"""Test Wyoming devices."""
|
||||
from __future__ import annotations
|
||||
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming import DOMAIN
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import device_registry as dr
|
||||
|
||||
|
||||
async def test_device_registry_info(
|
||||
hass: HomeAssistant,
|
||||
satellite_device: SatelliteDevice,
|
||||
satellite_config_entry: ConfigEntry,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
) -> None:
|
||||
"""Test info in device registry."""
|
||||
|
||||
# Satellite uses config entry id since only one satellite per entry is
|
||||
# supported.
|
||||
device = device_registry.async_get_device(
|
||||
identifiers={(DOMAIN, satellite_config_entry.entry_id)}
|
||||
)
|
||||
assert device is not None
|
||||
assert device.name == "Test Satellite"
|
||||
assert device.suggested_area == "Office"
|
||||
|
||||
# Check associated entities
|
||||
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||
assert assist_in_progress_id
|
||||
assist_in_progress_state = hass.states.get(assist_in_progress_id)
|
||||
assert assist_in_progress_state is not None
|
||||
assert assist_in_progress_state.state == STATE_OFF
|
||||
|
||||
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||
assert satellite_enabled_id
|
||||
satellite_enabled_state = hass.states.get(satellite_enabled_id)
|
||||
assert satellite_enabled_state is not None
|
||||
assert satellite_enabled_state.state == STATE_ON
|
||||
|
||||
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||
assert pipeline_entity_id
|
||||
pipeline_state = hass.states.get(pipeline_entity_id)
|
||||
assert pipeline_state is not None
|
||||
assert pipeline_state.state == OPTION_PREFERRED
|
||||
|
||||
|
||||
async def test_remove_device_registry_entry(
|
||||
hass: HomeAssistant,
|
||||
satellite_device: SatelliteDevice,
|
||||
device_registry: dr.DeviceRegistry,
|
||||
) -> None:
|
||||
"""Test removing a device registry entry."""
|
||||
|
||||
# Check associated entities
|
||||
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
|
||||
assert assist_in_progress_id
|
||||
assert hass.states.get(assist_in_progress_id) is not None
|
||||
|
||||
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||
assert satellite_enabled_id
|
||||
assert hass.states.get(satellite_enabled_id) is not None
|
||||
|
||||
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||
assert pipeline_entity_id
|
||||
assert hass.states.get(pipeline_entity_id) is not None
|
||||
|
||||
# Remove
|
||||
device_registry.async_remove_device(satellite_device.device_id)
|
||||
await hass.async_block_till_done()
|
||||
await hass.async_block_till_done()
|
||||
|
||||
# Everything should be gone
|
||||
assert hass.states.get(assist_in_progress_id) is None
|
||||
assert hass.states.get(satellite_enabled_id) is None
|
||||
assert hass.states.get(pipeline_entity_id) is None
|
|
@ -0,0 +1,460 @@
|
|||
"""Test Wyoming satellite."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from unittest.mock import patch
|
||||
import wave
|
||||
|
||||
from wyoming.asr import Transcribe, Transcript
|
||||
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
||||
from wyoming.event import Event
|
||||
from wyoming.pipeline import PipelineStage, RunPipeline
|
||||
from wyoming.satellite import RunSatellite
|
||||
from wyoming.tts import Synthesize
|
||||
from wyoming.vad import VoiceStarted, VoiceStopped
|
||||
from wyoming.wake import Detect, Detection
|
||||
|
||||
from homeassistant.components import assist_pipeline, wyoming
|
||||
from homeassistant.components.wyoming.data import WyomingService
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import SATELLITE_INFO, MockAsyncTcpClient
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
|
||||
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||
"""Set up config entry for Wyoming satellite.
|
||||
|
||||
This is separated from the satellite_config_entry method in conftest.py so
|
||||
we can patch functions before the satellite task is run during setup.
|
||||
"""
|
||||
entry = MockConfigEntry(
|
||||
domain="wyoming",
|
||||
data={
|
||||
"host": "1.2.3.4",
|
||||
"port": 1234,
|
||||
},
|
||||
title="Test Satellite",
|
||||
)
|
||||
entry.add_to_hass(hass)
|
||||
await hass.config_entries.async_setup(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
return entry
|
||||
|
||||
|
||||
def get_test_wav() -> bytes:
|
||||
"""Get bytes for test WAV file."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
|
||||
# Single frame
|
||||
wav_file.writeframes(b"123")
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
||||
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
|
||||
"""Satellite AsyncTcpClient."""
|
||||
|
||||
def __init__(self, responses: list[Event]) -> None:
|
||||
"""Initialize client."""
|
||||
super().__init__(responses)
|
||||
|
||||
self.connect_event = asyncio.Event()
|
||||
self.run_satellite_event = asyncio.Event()
|
||||
self.detect_event = asyncio.Event()
|
||||
|
||||
self.detection_event = asyncio.Event()
|
||||
self.detection: Detection | None = None
|
||||
|
||||
self.transcribe_event = asyncio.Event()
|
||||
self.transcribe: Transcribe | None = None
|
||||
|
||||
self.voice_started_event = asyncio.Event()
|
||||
self.voice_started: VoiceStarted | None = None
|
||||
|
||||
self.voice_stopped_event = asyncio.Event()
|
||||
self.voice_stopped: VoiceStopped | None = None
|
||||
|
||||
self.transcript_event = asyncio.Event()
|
||||
self.transcript: Transcript | None = None
|
||||
|
||||
self.synthesize_event = asyncio.Event()
|
||||
self.synthesize: Synthesize | None = None
|
||||
|
||||
self.tts_audio_start_event = asyncio.Event()
|
||||
self.tts_audio_chunk_event = asyncio.Event()
|
||||
self.tts_audio_stop_event = asyncio.Event()
|
||||
self.tts_audio_chunk: AudioChunk | None = None
|
||||
|
||||
self._mic_audio_chunk = AudioChunk(
|
||||
rate=16000, width=2, channels=1, audio=b"chunk"
|
||||
).event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect."""
|
||||
self.connect_event.set()
|
||||
|
||||
async def write_event(self, event: Event):
|
||||
"""Send."""
|
||||
if RunSatellite.is_type(event.type):
|
||||
self.run_satellite_event.set()
|
||||
elif Detect.is_type(event.type):
|
||||
self.detect_event.set()
|
||||
elif Detection.is_type(event.type):
|
||||
self.detection = Detection.from_event(event)
|
||||
self.detection_event.set()
|
||||
elif Transcribe.is_type(event.type):
|
||||
self.transcribe = Transcribe.from_event(event)
|
||||
self.transcribe_event.set()
|
||||
elif VoiceStarted.is_type(event.type):
|
||||
self.voice_started = VoiceStarted.from_event(event)
|
||||
self.voice_started_event.set()
|
||||
elif VoiceStopped.is_type(event.type):
|
||||
self.voice_stopped = VoiceStopped.from_event(event)
|
||||
self.voice_stopped_event.set()
|
||||
elif Transcript.is_type(event.type):
|
||||
self.transcript = Transcript.from_event(event)
|
||||
self.transcript_event.set()
|
||||
elif Synthesize.is_type(event.type):
|
||||
self.synthesize = Synthesize.from_event(event)
|
||||
self.synthesize_event.set()
|
||||
elif AudioStart.is_type(event.type):
|
||||
self.tts_audio_start_event.set()
|
||||
elif AudioChunk.is_type(event.type):
|
||||
self.tts_audio_chunk = AudioChunk.from_event(event)
|
||||
self.tts_audio_chunk_event.set()
|
||||
elif AudioStop.is_type(event.type):
|
||||
self.tts_audio_stop_event.set()
|
||||
|
||||
async def read_event(self) -> Event | None:
|
||||
"""Receive."""
|
||||
event = await super().read_event()
|
||||
|
||||
# Keep sending audio chunks instead of None
|
||||
return event or self._mic_audio_chunk
|
||||
|
||||
|
||||
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test running a pipeline with a satellite."""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
SatelliteAsyncTcpClient(events),
|
||||
) as mock_client, patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
|
||||
return_value=("wav", get_test_wav()),
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.connect_event.wait()
|
||||
await mock_client.run_satellite_event.wait()
|
||||
|
||||
mock_run_pipeline.assert_called()
|
||||
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
|
||||
|
||||
# Start detecting wake word
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.WAKE_WORD_START
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.detect_event.wait()
|
||||
|
||||
assert not device.is_active
|
||||
assert device.is_enabled
|
||||
|
||||
# Wake word is detected
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.WAKE_WORD_END,
|
||||
{"wake_word_output": {"wake_word_id": "test_wake_word"}},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.detection_event.wait()
|
||||
|
||||
assert mock_client.detection is not None
|
||||
assert mock_client.detection.name == "test_wake_word"
|
||||
|
||||
# "Assist in progress" sensor should be active now
|
||||
assert device.is_active
|
||||
|
||||
# Speech-to-text started
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_START,
|
||||
{"metadata": {"language": "en"}},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.transcribe_event.wait()
|
||||
|
||||
assert mock_client.transcribe is not None
|
||||
assert mock_client.transcribe.language == "en"
|
||||
|
||||
# User started speaking
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.voice_started_event.wait()
|
||||
|
||||
assert mock_client.voice_started is not None
|
||||
assert mock_client.voice_started.timestamp == 1234
|
||||
|
||||
# User stopped speaking
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.voice_stopped_event.wait()
|
||||
|
||||
assert mock_client.voice_stopped is not None
|
||||
assert mock_client.voice_stopped.timestamp == 5678
|
||||
|
||||
# Speech-to-text transcription
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.STT_END,
|
||||
{"stt_output": {"text": "test transcript"}},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.transcript_event.wait()
|
||||
|
||||
assert mock_client.transcript is not None
|
||||
assert mock_client.transcript.text == "test transcript"
|
||||
|
||||
# Text-to-speech text
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_START,
|
||||
{
|
||||
"tts_input": "test text to speak",
|
||||
"voice": "test voice",
|
||||
},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.synthesize_event.wait()
|
||||
|
||||
assert mock_client.synthesize is not None
|
||||
assert mock_client.synthesize.text == "test text to speak"
|
||||
assert mock_client.synthesize.voice is not None
|
||||
assert mock_client.synthesize.voice.name == "test voice"
|
||||
|
||||
# Text-to-speech media
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(
|
||||
assist_pipeline.PipelineEventType.TTS_END,
|
||||
{"tts_output": {"media_id": "test media id"}},
|
||||
)
|
||||
)
|
||||
async with asyncio.timeout(1):
|
||||
await mock_client.tts_audio_start_event.wait()
|
||||
await mock_client.tts_audio_chunk_event.wait()
|
||||
await mock_client.tts_audio_stop_event.wait()
|
||||
|
||||
# Verify audio chunk from test WAV
|
||||
assert mock_client.tts_audio_chunk is not None
|
||||
assert mock_client.tts_audio_chunk.rate == 22050
|
||||
assert mock_client.tts_audio_chunk.width == 2
|
||||
assert mock_client.tts_audio_chunk.channels == 1
|
||||
assert mock_client.tts_audio_chunk.audio == b"123"
|
||||
|
||||
# Pipeline finished
|
||||
event_callback(
|
||||
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
|
||||
)
|
||||
assert not device.is_active
|
||||
|
||||
# Stop the satellite
|
||||
await hass.config_entries.async_unload(entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
|
||||
async def test_satellite_disabled(hass: HomeAssistant) -> None:
|
||||
"""Test callback for a satellite that has been disabled."""
|
||||
on_disabled_event = asyncio.Event()
|
||||
|
||||
original_make_satellite = wyoming._make_satellite
|
||||
|
||||
def make_disabled_satellite(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
|
||||
):
|
||||
satellite = original_make_satellite(hass, config_entry, service)
|
||||
satellite.device.is_enabled = False
|
||||
|
||||
return satellite
|
||||
|
||||
async def on_disabled(self):
|
||||
on_disabled_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming._make_satellite", make_disabled_satellite
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_disabled",
|
||||
on_disabled,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_disabled_event.wait()
|
||||
|
||||
|
||||
async def test_satellite_restart(hass: HomeAssistant) -> None:
|
||||
"""Test pipeline loop restart after unexpected error."""
|
||||
on_restart_event = asyncio.Event()
|
||||
|
||||
async def on_restart(self):
|
||||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
|
||||
side_effect=RuntimeError(),
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
||||
|
||||
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
|
||||
"""Test satellite reconnect call after connection refused."""
|
||||
on_reconnect_event = asyncio.Event()
|
||||
|
||||
async def on_reconnect(self):
|
||||
self.stop()
|
||||
on_reconnect_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
|
||||
side_effect=ConnectionRefusedError(),
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
|
||||
on_reconnect,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_reconnect_event.wait()
|
||||
|
||||
|
||||
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test satellite disconnecting before pipeline run."""
|
||||
on_restart_event = asyncio.Event()
|
||||
|
||||
async def on_restart(self):
|
||||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient([]), # no RunPipeline event
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
):
|
||||
await setup_config_entry(hass)
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
|
||||
# Pipeline should never have run
|
||||
mock_run_pipeline.assert_not_called()
|
||||
|
||||
|
||||
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
|
||||
"""Test satellite disconnecting during pipeline run."""
|
||||
events = [
|
||||
RunPipeline(
|
||||
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
|
||||
).event(),
|
||||
] # no audio chunks after RunPipeline
|
||||
|
||||
on_restart_event = asyncio.Event()
|
||||
on_stopped_event = asyncio.Event()
|
||||
|
||||
async def on_restart(self):
|
||||
# Pretend sensor got stuck on
|
||||
self.device.is_active = True
|
||||
self.stop()
|
||||
on_restart_event.set()
|
||||
|
||||
async def on_stopped(self):
|
||||
on_stopped_event.set()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.wyoming.data.load_wyoming_info",
|
||||
return_value=SATELLITE_INFO,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
|
||||
MockAsyncTcpClient(events),
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
|
||||
) as mock_run_pipeline, patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
|
||||
on_restart,
|
||||
), patch(
|
||||
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
|
||||
on_stopped,
|
||||
):
|
||||
entry = await setup_config_entry(hass)
|
||||
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
|
||||
entry.entry_id
|
||||
].satellite.device
|
||||
|
||||
async with asyncio.timeout(1):
|
||||
await on_restart_event.wait()
|
||||
await on_stopped_event.wait()
|
||||
|
||||
# Pipeline should have run once
|
||||
mock_run_pipeline.assert_called_once()
|
||||
|
||||
# Sensor should have been turned off
|
||||
assert not device.is_active
|
|
@ -0,0 +1,83 @@
|
|||
"""Test Wyoming select."""
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
from homeassistant.components.assist_pipeline.pipeline import PipelineData
|
||||
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
|
||||
async def test_pipeline_select(
|
||||
hass: HomeAssistant,
|
||||
satellite_config_entry: ConfigEntry,
|
||||
satellite_device: SatelliteDevice,
|
||||
) -> None:
|
||||
"""Test pipeline select.
|
||||
|
||||
Functionality is tested in assist_pipeline/test_select.py.
|
||||
This test is only to ensure it is set up.
|
||||
"""
|
||||
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
|
||||
pipeline_data: PipelineData = hass.data[assist_pipeline.DOMAIN]
|
||||
|
||||
# Create second pipeline
|
||||
await pipeline_data.pipeline_store.async_create_item(
|
||||
{
|
||||
"name": "Test 1",
|
||||
"language": "en-US",
|
||||
"conversation_engine": None,
|
||||
"conversation_language": "en-US",
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
|
||||
# Preferred pipeline is the default
|
||||
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
|
||||
assert pipeline_entity_id
|
||||
|
||||
state = hass.states.get(pipeline_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == OPTION_PREFERRED
|
||||
|
||||
# Change to second pipeline
|
||||
with patch.object(satellite_device, "set_pipeline_name") as mock_pipeline_changed:
|
||||
await hass.services.async_call(
|
||||
"select",
|
||||
"select_option",
|
||||
{"entity_id": pipeline_entity_id, "option": "Test 1"},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get(pipeline_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == "Test 1"
|
||||
|
||||
# async_pipeline_changed should have been called
|
||||
mock_pipeline_changed.assert_called_once_with("Test 1")
|
||||
|
||||
# Change back and check update listener
|
||||
pipeline_listener = Mock()
|
||||
satellite_device.set_pipeline_listener(pipeline_listener)
|
||||
|
||||
await hass.services.async_call(
|
||||
"select",
|
||||
"select_option",
|
||||
{"entity_id": pipeline_entity_id, "option": OPTION_PREFERRED},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get(pipeline_entity_id)
|
||||
assert state is not None
|
||||
assert state.state == OPTION_PREFERRED
|
||||
|
||||
# listener should have been called
|
||||
pipeline_listener.assert_called_once()
|
|
@ -0,0 +1,32 @@
|
|||
"""Test Wyoming switch devices."""
|
||||
from homeassistant.components.wyoming.devices import SatelliteDevice
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import STATE_OFF, STATE_ON
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
|
||||
async def test_satellite_enabled(
|
||||
hass: HomeAssistant,
|
||||
satellite_config_entry: ConfigEntry,
|
||||
satellite_device: SatelliteDevice,
|
||||
) -> None:
|
||||
"""Test satellite enabled."""
|
||||
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
|
||||
assert satellite_enabled_id
|
||||
|
||||
state = hass.states.get(satellite_enabled_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_ON
|
||||
assert satellite_device.is_enabled
|
||||
|
||||
await hass.services.async_call(
|
||||
"switch",
|
||||
"turn_off",
|
||||
{"entity_id": satellite_enabled_id},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
state = hass.states.get(satellite_enabled_id)
|
||||
assert state is not None
|
||||
assert state.state == STATE_OFF
|
||||
assert not satellite_device.is_enabled
|
Loading…
Reference in New Issue