Add Wyoming satellite (#104759)

* First draft of Wyoming satellite

* Set up homeassistant in tests

* Move satellite

* Add devices with binary sensor and select

* Add more events

* Add satellite enabled switch

* Fix mistake

* Only set up necessary platforms for satellites

* Lots of fixes

* Add tests

* Use config entry id as satellite id

* Initial satellite test

* Add satellite pipeline test

* More tests

* More satellite tests

* Only support single device per config entry

* Address comments

* Make a copy of platforms
pull/105135/head
Michael Hansen 2023-12-04 14:13:15 -06:00 committed by Franck Nijhof
parent db6b804298
commit 5a49e1dd5c
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
28 changed files with 1802 additions and 60 deletions

View File

@ -4,17 +4,26 @@ from __future__ import annotations
import logging import logging
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr
from .const import ATTR_SPEAKER, DOMAIN from .const import ATTR_SPEAKER, DOMAIN
from .data import WyomingService from .data import WyomingService
from .devices import SatelliteDevice
from .models import DomainDataItem
from .satellite import WyomingSatellite
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SATELLITE_PLATFORMS = [Platform.BINARY_SENSOR, Platform.SELECT, Platform.SWITCH]
__all__ = [ __all__ = [
"ATTR_SPEAKER", "ATTR_SPEAKER",
"DOMAIN", "DOMAIN",
"async_setup_entry",
"async_unload_entry",
] ]
@ -25,24 +34,72 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
if service is None: if service is None:
raise ConfigEntryNotReady("Unable to connect") raise ConfigEntryNotReady("Unable to connect")
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = service item = DomainDataItem(service=service)
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = item
await hass.config_entries.async_forward_entry_setups( await hass.config_entries.async_forward_entry_setups(entry, service.platforms)
entry, entry.async_on_unload(entry.add_update_listener(update_listener))
service.platforms,
) if (satellite_info := service.info.satellite) is not None:
# Create satellite device, etc.
item.satellite = _make_satellite(hass, entry, service)
# Set up satellite sensors, switches, etc.
await hass.config_entries.async_forward_entry_setups(entry, SATELLITE_PLATFORMS)
# Start satellite communication
entry.async_create_background_task(
hass,
item.satellite.run(),
f"Satellite {satellite_info.name}",
)
entry.async_on_unload(item.satellite.stop)
return True return True
def _make_satellite(
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
) -> WyomingSatellite:
"""Create Wyoming satellite/device from config entry and Wyoming service."""
satellite_info = service.info.satellite
assert satellite_info is not None
dev_reg = dr.async_get(hass)
# Use config entry id since only one satellite per entry is supported
satellite_id = config_entry.entry_id
device = dev_reg.async_get_or_create(
config_entry_id=config_entry.entry_id,
identifiers={(DOMAIN, satellite_id)},
name=satellite_info.name,
suggested_area=satellite_info.area,
)
satellite_device = SatelliteDevice(
satellite_id=satellite_id,
device_id=device.id,
)
return WyomingSatellite(hass, service, satellite_device)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Wyoming.""" """Unload Wyoming."""
service: WyomingService = hass.data[DOMAIN][entry.entry_id] item: DomainDataItem = hass.data[DOMAIN][entry.entry_id]
unload_ok = await hass.config_entries.async_unload_platforms( platforms = list(item.service.platforms)
entry, if item.satellite is not None:
service.platforms, platforms += SATELLITE_PLATFORMS
)
unload_ok = await hass.config_entries.async_unload_platforms(entry, platforms)
if unload_ok: if unload_ok:
del hass.data[DOMAIN][entry.entry_id] del hass.data[DOMAIN][entry.entry_id]

View File

@ -0,0 +1,55 @@
"""Binary sensor for Wyoming."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.components.binary_sensor import (
BinarySensorEntity,
BinarySensorEntityDescription,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
if TYPE_CHECKING:
from .models import DomainDataItem
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up binary sensor entities."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
async_add_entities([WyomingSatelliteAssistInProgress(item.satellite.device)])
class WyomingSatelliteAssistInProgress(WyomingSatelliteEntity, BinarySensorEntity):
"""Entity to represent Assist is in progress for satellite."""
entity_description = BinarySensorEntityDescription(
key="assist_in_progress",
translation_key="assist_in_progress",
)
_attr_is_on = False
async def async_added_to_hass(self) -> None:
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
self._device.set_is_active_listener(self._is_active_changed)
@callback
def _is_active_changed(self) -> None:
"""Call when active state changed."""
self._attr_is_on = self._device.is_active
self.async_write_ha_state()

View File

@ -1,19 +1,22 @@
"""Config flow for Wyoming integration.""" """Config flow for Wyoming integration."""
from __future__ import annotations from __future__ import annotations
import logging
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.hassio import HassioServiceInfo from homeassistant.components import hassio, zeroconf
from homeassistant.const import CONF_HOST, CONF_PORT from homeassistant.const import CONF_HOST, CONF_NAME, CONF_PORT
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from .const import DOMAIN from .const import DOMAIN
from .data import WyomingService from .data import WyomingService
_LOGGER = logging.getLogger()
STEP_USER_DATA_SCHEMA = vol.Schema( STEP_USER_DATA_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): str,
@ -27,7 +30,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
_hassio_discovery: HassioServiceInfo _hassio_discovery: hassio.HassioServiceInfo
_service: WyomingService | None = None
_name: str | None = None
async def async_step_user( async def async_step_user(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -50,27 +55,14 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors={"base": "cannot_connect"}, errors={"base": "cannot_connect"},
) )
# ASR = automated speech recognition (speech-to-text) if name := service.get_name():
asr_installed = [asr for asr in service.info.asr if asr.installed] return self.async_create_entry(title=name, data=user_input)
# TTS = text-to-speech return self.async_abort(reason="no_services")
tts_installed = [tts for tts in service.info.tts if tts.installed]
# wake-word-detection async def async_step_hassio(
wake_installed = [wake for wake in service.info.wake if wake.installed] self, discovery_info: hassio.HassioServiceInfo
) -> FlowResult:
if asr_installed:
name = asr_installed[0].name
elif tts_installed:
name = tts_installed[0].name
elif wake_installed:
name = wake_installed[0].name
else:
return self.async_abort(reason="no_services")
return self.async_create_entry(title=name, data=user_input)
async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult:
"""Handle Supervisor add-on discovery.""" """Handle Supervisor add-on discovery."""
await self.async_set_unique_id(discovery_info.uuid) await self.async_set_unique_id(discovery_info.uuid)
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
@ -93,11 +85,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if user_input is not None: if user_input is not None:
uri = urlparse(self._hassio_discovery.config["uri"]) uri = urlparse(self._hassio_discovery.config["uri"])
if service := await WyomingService.create(uri.hostname, uri.port): if service := await WyomingService.create(uri.hostname, uri.port):
if ( if not service.has_services():
not any(asr for asr in service.info.asr if asr.installed)
and not any(tts for tts in service.info.tts if tts.installed)
and not any(wake for wake in service.info.wake if wake.installed)
):
return self.async_abort(reason="no_services") return self.async_abort(reason="no_services")
return self.async_create_entry( return self.async_create_entry(
@ -112,3 +100,52 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
description_placeholders={"addon": self._hassio_discovery.name}, description_placeholders={"addon": self._hassio_discovery.name},
errors=errors, errors=errors,
) )
async def async_step_zeroconf(
self, discovery_info: zeroconf.ZeroconfServiceInfo
) -> FlowResult:
"""Handle zeroconf discovery."""
_LOGGER.debug("Discovery info: %s", discovery_info)
if discovery_info.port is None:
return self.async_abort(reason="no_port")
service = await WyomingService.create(discovery_info.host, discovery_info.port)
if (service is None) or (not (name := service.get_name())):
# No supported services
return self.async_abort(reason="no_services")
self._name = name
# Use zeroconf name + service name as unique id.
# The satellite will use its own MAC as the zeroconf name by default.
unique_id = f"{discovery_info.name}_{self._name}"
await self.async_set_unique_id(unique_id)
self._abort_if_unique_id_configured()
self.context[CONF_NAME] = self._name
self.context["title_placeholders"] = {"name": self._name}
self._service = service
return await self.async_step_zeroconf_confirm()
async def async_step_zeroconf_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle a flow initiated by zeroconf."""
assert self._service is not None
assert self._name is not None
if user_input is None:
return self.async_show_form(
step_id="zeroconf_confirm",
description_placeholders={"name": self._name},
errors={},
)
return self.async_create_entry(
title=self._name,
data={
CONF_HOST: self._service.host,
CONF_PORT: self._service.port,
},
)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import asyncio import asyncio
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.info import Describe, Info from wyoming.info import Describe, Info, Satellite
from homeassistant.const import Platform from homeassistant.const import Platform
@ -32,6 +32,43 @@ class WyomingService:
platforms.append(Platform.WAKE_WORD) platforms.append(Platform.WAKE_WORD)
self.platforms = platforms self.platforms = platforms
def has_services(self) -> bool:
"""Return True if services are installed that Home Assistant can use."""
return (
any(asr for asr in self.info.asr if asr.installed)
or any(tts for tts in self.info.tts if tts.installed)
or any(wake for wake in self.info.wake if wake.installed)
or ((self.info.satellite is not None) and self.info.satellite.installed)
)
def get_name(self) -> str | None:
"""Return name of first installed usable service."""
# ASR = automated speech recognition (speech-to-text)
asr_installed = [asr for asr in self.info.asr if asr.installed]
if asr_installed:
return asr_installed[0].name
# TTS = text-to-speech
tts_installed = [tts for tts in self.info.tts if tts.installed]
if tts_installed:
return tts_installed[0].name
# wake-word-detection
wake_installed = [wake for wake in self.info.wake if wake.installed]
if wake_installed:
return wake_installed[0].name
# satellite
satellite_installed: Satellite | None = None
if (self.info.satellite is not None) and self.info.satellite.installed:
satellite_installed = self.info.satellite
if satellite_installed:
return satellite_installed.name
return None
@classmethod @classmethod
async def create(cls, host: str, port: int) -> WyomingService | None: async def create(cls, host: str, port: int) -> WyomingService | None:
"""Create a Wyoming service.""" """Create a Wyoming service."""

View File

@ -0,0 +1,85 @@
"""Class to manage satellite devices."""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_registry as er
from .const import DOMAIN
@dataclass
class SatelliteDevice:
"""Class to store device."""
satellite_id: str
device_id: str
is_active: bool = False
is_enabled: bool = True
pipeline_name: str | None = None
_is_active_listener: Callable[[], None] | None = None
_is_enabled_listener: Callable[[], None] | None = None
_pipeline_listener: Callable[[], None] | None = None
@callback
def set_is_active(self, active: bool) -> None:
"""Set active state."""
if active != self.is_active:
self.is_active = active
if self._is_active_listener is not None:
self._is_active_listener()
@callback
def set_is_enabled(self, enabled: bool) -> None:
"""Set enabled state."""
if enabled != self.is_enabled:
self.is_enabled = enabled
if self._is_enabled_listener is not None:
self._is_enabled_listener()
@callback
def set_pipeline_name(self, pipeline_name: str) -> None:
"""Inform listeners that pipeline selection has changed."""
if pipeline_name != self.pipeline_name:
self.pipeline_name = pipeline_name
if self._pipeline_listener is not None:
self._pipeline_listener()
@callback
def set_is_active_listener(self, is_active_listener: Callable[[], None]) -> None:
"""Listen for updates to is_active."""
self._is_active_listener = is_active_listener
@callback
def set_is_enabled_listener(self, is_enabled_listener: Callable[[], None]) -> None:
"""Listen for updates to is_enabled."""
self._is_enabled_listener = is_enabled_listener
@callback
def set_pipeline_listener(self, pipeline_listener: Callable[[], None]) -> None:
"""Listen for updates to pipeline."""
self._pipeline_listener = pipeline_listener
def get_assist_in_progress_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for assist in progress binary sensor."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"binary_sensor", DOMAIN, f"{self.satellite_id}-assist_in_progress"
)
def get_satellite_enabled_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for satellite enabled switch."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"switch", DOMAIN, f"{self.satellite_id}-satellite_enabled"
)
def get_pipeline_entity_id(self, hass: HomeAssistant) -> str | None:
"""Return entity id for pipeline select."""
ent_reg = er.async_get(hass)
return ent_reg.async_get_entity_id(
"select", DOMAIN, f"{self.satellite_id}-pipeline"
)

View File

@ -0,0 +1,24 @@
"""Wyoming entities."""
from __future__ import annotations
from homeassistant.helpers import entity
from homeassistant.helpers.device_registry import DeviceInfo
from .const import DOMAIN
from .satellite import SatelliteDevice
class WyomingSatelliteEntity(entity.Entity):
"""Wyoming satellite entity."""
_attr_has_entity_name = True
_attr_should_poll = False
def __init__(self, device: SatelliteDevice) -> None:
"""Initialize entity."""
self._device = device
self._attr_unique_id = f"{device.satellite_id}-{self.entity_description.key}"
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, device.satellite_id)},
)

View File

@ -3,7 +3,9 @@
"name": "Wyoming Protocol", "name": "Wyoming Protocol",
"codeowners": ["@balloob", "@synesthesiam"], "codeowners": ["@balloob", "@synesthesiam"],
"config_flow": true, "config_flow": true,
"dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming", "documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["wyoming==1.2.0"] "requirements": ["wyoming==1.3.0"],
"zeroconf": ["_wyoming._tcp.local."]
} }

View File

@ -0,0 +1,13 @@
"""Models for wyoming."""
from dataclasses import dataclass
from .data import WyomingService
from .satellite import WyomingSatellite
@dataclass
class DomainDataItem:
"""Domain data item."""
service: WyomingService
satellite: WyomingSatellite | None = None

View File

@ -0,0 +1,380 @@
"""Support for Wyoming satellite services."""
import asyncio
from collections.abc import AsyncGenerator
import io
import logging
from typing import Final
import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize, SynthesizeVoice
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, stt, tts
from homeassistant.components.assist_pipeline import select as pipeline_select
from homeassistant.core import Context, HomeAssistant
from .const import DOMAIN
from .data import WyomingService
from .devices import SatelliteDevice
_LOGGER = logging.getLogger()
_SAMPLES_PER_CHUNK: Final = 1024
_RECONNECT_SECONDS: Final = 10
_RESTART_SECONDS: Final = 3
# Wyoming stage -> Assist stage
_STAGES: dict[PipelineStage, assist_pipeline.PipelineStage] = {
PipelineStage.WAKE: assist_pipeline.PipelineStage.WAKE_WORD,
PipelineStage.ASR: assist_pipeline.PipelineStage.STT,
PipelineStage.HANDLE: assist_pipeline.PipelineStage.INTENT,
PipelineStage.TTS: assist_pipeline.PipelineStage.TTS,
}
class WyomingSatellite:
"""Remove voice satellite running the Wyoming protocol."""
def __init__(
self, hass: HomeAssistant, service: WyomingService, device: SatelliteDevice
) -> None:
"""Initialize satellite."""
self.hass = hass
self.service = service
self.device = device
self.is_enabled = True
self.is_running = True
self._client: AsyncTcpClient | None = None
self._chunk_converter = AudioChunkConverter(rate=16000, width=2, channels=1)
self._is_pipeline_running = False
self._audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
self._pipeline_id: str | None = None
self._enabled_changed_event = asyncio.Event()
self.device.set_is_enabled_listener(self._enabled_changed)
self.device.set_pipeline_listener(self._pipeline_changed)
async def run(self) -> None:
"""Run and maintain a connection to satellite."""
_LOGGER.debug("Running satellite task")
try:
while self.is_running:
try:
# Check if satellite has been disabled
if not self.device.is_enabled:
await self.on_disabled()
if not self.is_running:
# Satellite was stopped while waiting to be enabled
break
# Connect and run pipeline loop
await self._run_once()
except asyncio.CancelledError:
raise
except Exception: # pylint: disable=broad-exception-caught
await self.on_restart()
finally:
# Ensure sensor is off
self.device.set_is_active(False)
await self.on_stopped()
def stop(self) -> None:
"""Signal satellite task to stop running."""
self.is_running = False
# Unblock waiting for enabled
self._enabled_changed_event.set()
async def on_restart(self) -> None:
"""Block until pipeline loop will be restarted."""
_LOGGER.warning(
"Unexpected error running satellite. Restarting in %s second(s)",
_RECONNECT_SECONDS,
)
await asyncio.sleep(_RESTART_SECONDS)
async def on_reconnect(self) -> None:
"""Block until a reconnection attempt should be made."""
_LOGGER.debug(
"Failed to connect to satellite. Reconnecting in %s second(s)",
_RECONNECT_SECONDS,
)
await asyncio.sleep(_RECONNECT_SECONDS)
async def on_disabled(self) -> None:
"""Block until device may be enabled again."""
await self._enabled_changed_event.wait()
async def on_stopped(self) -> None:
"""Run when run() has fully stopped."""
_LOGGER.debug("Satellite task stopped")
# -------------------------------------------------------------------------
def _enabled_changed(self) -> None:
"""Run when device enabled status changes."""
if not self.device.is_enabled:
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
self._enabled_changed_event.set()
def _pipeline_changed(self) -> None:
"""Run when device pipeline changes."""
# Cancel any running pipeline
self._audio_queue.put_nowait(None)
async def _run_once(self) -> None:
"""Run pipelines until an error occurs."""
self.device.set_is_active(False)
while self.is_running and self.is_enabled:
try:
await self._connect()
break
except ConnectionError:
await self.on_reconnect()
assert self._client is not None
_LOGGER.debug("Connected to satellite")
if (not self.is_running) or (not self.is_enabled):
# Run was cancelled or satellite was disabled during connection
return
# Tell satellite that we're ready
await self._client.write_event(RunSatellite().event())
# Wait until we get RunPipeline event
run_pipeline: RunPipeline | None = None
while self.is_running and self.is_enabled:
run_event = await self._client.read_event()
if run_event is None:
raise ConnectionResetError("Satellite disconnected")
if RunPipeline.is_type(run_event.type):
run_pipeline = RunPipeline.from_event(run_event)
break
_LOGGER.debug("Unexpected event from satellite: %s", run_event)
assert run_pipeline is not None
_LOGGER.debug("Received run information: %s", run_pipeline)
if (not self.is_running) or (not self.is_enabled):
# Run was cancelled or satellite was disabled while waiting for
# RunPipeline event.
return
start_stage = _STAGES.get(run_pipeline.start_stage)
end_stage = _STAGES.get(run_pipeline.end_stage)
if start_stage is None:
raise ValueError(f"Invalid start stage: {start_stage}")
if end_stage is None:
raise ValueError(f"Invalid end stage: {end_stage}")
# Each loop is a pipeline run
while self.is_running and self.is_enabled:
# Use select to get pipeline each time in case it's changed
pipeline_id = pipeline_select.get_chosen_pipeline(
self.hass,
DOMAIN,
self.device.satellite_id,
)
pipeline = assist_pipeline.async_get_pipeline(self.hass, pipeline_id)
assert pipeline is not None
# We will push audio in through a queue
self._audio_queue = asyncio.Queue()
stt_stream = self._stt_stream()
# Start pipeline running
_LOGGER.debug(
"Starting pipeline %s from %s to %s",
pipeline.name,
start_stage,
end_stage,
)
self._is_pipeline_running = True
_pipeline_task = asyncio.create_task(
assist_pipeline.async_pipeline_from_audio_stream(
self.hass,
context=Context(),
event_callback=self._event_callback,
stt_metadata=stt.SpeechMetadata(
language=pipeline.language,
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=stt_stream,
start_stage=start_stage,
end_stage=end_stage,
tts_audio_output="wav",
pipeline_id=pipeline_id,
)
)
# Run until pipeline is complete or cancelled with an empty audio chunk
while self._is_pipeline_running:
client_event = await self._client.read_event()
if client_event is None:
raise ConnectionResetError("Satellite disconnected")
if AudioChunk.is_type(client_event.type):
# Microphone audio
chunk = AudioChunk.from_event(client_event)
chunk = self._chunk_converter.convert(chunk)
self._audio_queue.put_nowait(chunk.audio)
else:
_LOGGER.debug("Unexpected event from satellite: %s", client_event)
_LOGGER.debug("Pipeline finished")
def _event_callback(self, event: assist_pipeline.PipelineEvent) -> None:
"""Translate pipeline events into Wyoming events."""
assert self._client is not None
if event.type == assist_pipeline.PipelineEventType.RUN_END:
# Pipeline run is complete
self._is_pipeline_running = False
self.device.set_is_active(False)
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
self.hass.add_job(self._client.write_event(Detect().event()))
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
# Wake word detection
self.device.set_is_active(True)
# Inform client of wake word detection
if event.data and (wake_word_output := event.data.get("wake_word_output")):
detection = Detection(
name=wake_word_output["wake_word_id"],
timestamp=wake_word_output.get("timestamp"),
)
self.hass.add_job(self._client.write_event(detection.event()))
elif event.type == assist_pipeline.PipelineEventType.STT_START:
# Speech-to-text
self.device.set_is_active(True)
if event.data:
self.hass.add_job(
self._client.write_event(
Transcribe(language=event.data["metadata"]["language"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
# User started speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStarted(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
# User stopped speaking
if event.data:
self.hass.add_job(
self._client.write_event(
VoiceStopped(timestamp=event.data["timestamp"]).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.STT_END:
# Speech-to-text transcript
if event.data:
# Inform client of transript
stt_text = event.data["stt_output"]["text"]
self.hass.add_job(
self._client.write_event(Transcript(text=stt_text).event())
)
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
# Text-to-speech text
if event.data:
# Inform client of text
self.hass.add_job(
self._client.write_event(
Synthesize(
text=event.data["tts_input"],
voice=SynthesizeVoice(
name=event.data.get("voice"),
language=event.data.get("language"),
),
).event()
)
)
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
# TTS stream
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
async def _connect(self) -> None:
"""Connect to satellite over TCP."""
_LOGGER.debug(
"Connecting to satellite at %s:%s", self.service.host, self.service.port
)
self._client = AsyncTcpClient(self.service.host, self.service.port)
await self._client.connect()
async def _stream_tts(self, media_id: str) -> None:
"""Stream TTS WAV audio to satellite in chunks."""
assert self._client is not None
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
if extension != "wav":
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
sample_rate = wav_file.getframerate()
sample_width = wav_file.getsampwidth()
sample_channels = wav_file.getnchannels()
_LOGGER.debug("Streaming %s TTS sample(s)", wav_file.getnframes())
timestamp = 0
await self._client.write_event(
AudioStart(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
timestamp=timestamp,
).event()
)
# Stream audio chunks
while audio_bytes := wav_file.readframes(_SAMPLES_PER_CHUNK):
chunk = AudioChunk(
rate=sample_rate,
width=sample_width,
channels=sample_channels,
audio=audio_bytes,
timestamp=timestamp,
)
await self._client.write_event(chunk.event())
timestamp += chunk.seconds
await self._client.write_event(AudioStop(timestamp=timestamp).event())
_LOGGER.debug("TTS streaming complete")
async def _stt_stream(self) -> AsyncGenerator[bytes, None]:
"""Yield audio chunks from a queue."""
is_first_chunk = True
while chunk := await self._audio_queue.get():
if is_first_chunk:
is_first_chunk = False
_LOGGER.debug("Receiving audio from satellite")
yield chunk

View File

@ -0,0 +1,47 @@
"""Select entities for VoIP integration."""
from __future__ import annotations
from typing import TYPE_CHECKING
from homeassistant.components.assist_pipeline.select import AssistPipelineSelect
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .devices import SatelliteDevice
from .entity import WyomingSatelliteEntity
if TYPE_CHECKING:
from .models import DomainDataItem
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP switch entities."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
async_add_entities([WyomingSatellitePipelineSelect(hass, item.satellite.device)])
class WyomingSatellitePipelineSelect(WyomingSatelliteEntity, AssistPipelineSelect):
"""Pipeline selector for Wyoming satellites."""
def __init__(self, hass: HomeAssistant, device: SatelliteDevice) -> None:
"""Initialize a pipeline selector."""
self.device = device
WyomingSatelliteEntity.__init__(self, device)
AssistPipelineSelect.__init__(self, hass, device.satellite_id)
async def async_select_option(self, option: str) -> None:
"""Select an option."""
await super().async_select_option(option)
self.device.set_pipeline_name(option)

View File

@ -9,6 +9,10 @@
}, },
"hassio_confirm": { "hassio_confirm": {
"description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?" "description": "Do you want to configure Home Assistant to connect to the Wyoming service provided by the add-on: {addon}?"
},
"zeroconf_confirm": {
"description": "Do you want to configure Home Assistant to connect to the Wyoming service {name}?",
"title": "Discovered Wyoming service"
} }
}, },
"error": { "error": {
@ -16,7 +20,31 @@
}, },
"abort": { "abort": {
"already_configured": "[%key:common::config_flow::abort::already_configured_service%]", "already_configured": "[%key:common::config_flow::abort::already_configured_service%]",
"no_services": "No services found at endpoint" "no_services": "No services found at endpoint",
"no_port": "No port for endpoint"
}
},
"entity": {
"binary_sensor": {
"assist_in_progress": {
"name": "[%key:component::assist_pipeline::entity::binary_sensor::assist_in_progress::name%]"
}
},
"select": {
"pipeline": {
"name": "[%key:component::assist_pipeline::entity::select::pipeline::name%]",
"state": {
"preferred": "[%key:component::assist_pipeline::entity::select::pipeline::state::preferred%]"
}
},
"noise_suppression": {
"name": "Noise suppression"
}
},
"switch": {
"satellite_enabled": {
"name": "Satellite enabled"
}
} }
} }
} }

View File

@ -14,6 +14,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
from .data import WyomingService from .data import WyomingService
from .error import WyomingError from .error import WyomingError
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -24,10 +25,10 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Wyoming speech-to-text.""" """Set up Wyoming speech-to-text."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities( async_add_entities(
[ [
WyomingSttProvider(config_entry, service), WyomingSttProvider(config_entry, item.service),
] ]
) )

View File

@ -0,0 +1,65 @@
"""Wyoming switch entities."""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from homeassistant.components.switch import SwitchEntity, SwitchEntityDescription
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_ON, EntityCategory
from homeassistant.core import HomeAssistant
from homeassistant.helpers import restore_state
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN
from .entity import WyomingSatelliteEntity
if TYPE_CHECKING:
from .models import DomainDataItem
async def async_setup_entry(
hass: HomeAssistant,
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up VoIP switch entities."""
item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
# Setup is only forwarded for satellites
assert item.satellite is not None
async_add_entities([WyomingSatelliteEnabledSwitch(item.satellite.device)])
class WyomingSatelliteEnabledSwitch(
WyomingSatelliteEntity, restore_state.RestoreEntity, SwitchEntity
):
"""Entity to represent if satellite is enabled."""
entity_description = SwitchEntityDescription(
key="satellite_enabled",
translation_key="satellite_enabled",
entity_category=EntityCategory.CONFIG,
)
async def async_added_to_hass(self) -> None:
"""Call when entity about to be added to hass."""
await super().async_added_to_hass()
state = await self.async_get_last_state()
# Default to on
self._attr_is_on = (state is None) or (state.state == STATE_ON)
async def async_turn_on(self, **kwargs: Any) -> None:
"""Turn on."""
self._attr_is_on = True
self.async_write_ha_state()
self._device.set_is_enabled(True)
async def async_turn_off(self, **kwargs: Any) -> None:
"""Turn off."""
self._attr_is_on = False
self.async_write_ha_state()
self._device.set_is_enabled(False)

View File

@ -16,6 +16,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import ATTR_SPEAKER, DOMAIN from .const import ATTR_SPEAKER, DOMAIN
from .data import WyomingService from .data import WyomingService
from .error import WyomingError from .error import WyomingError
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -26,10 +27,10 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Wyoming speech-to-text.""" """Set up Wyoming speech-to-text."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities( async_add_entities(
[ [
WyomingTtsProvider(config_entry, service), WyomingTtsProvider(config_entry, item.service),
] ]
) )

View File

@ -15,6 +15,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DOMAIN from .const import DOMAIN
from .data import WyomingService, load_wyoming_info from .data import WyomingService, load_wyoming_info
from .error import WyomingError from .error import WyomingError
from .models import DomainDataItem
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -25,10 +26,10 @@ async def async_setup_entry(
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up Wyoming wake-word-detection.""" """Set up Wyoming wake-word-detection."""
service: WyomingService = hass.data[DOMAIN][config_entry.entry_id] item: DomainDataItem = hass.data[DOMAIN][config_entry.entry_id]
async_add_entities( async_add_entities(
[ [
WyomingWakeWordProvider(hass, config_entry, service), WyomingWakeWordProvider(hass, config_entry, item.service),
] ]
) )

View File

@ -715,6 +715,11 @@ ZEROCONF = {
"domain": "wled", "domain": "wled",
}, },
], ],
"_wyoming._tcp.local.": [
{
"domain": "wyoming",
},
],
"_xbmc-jsonrpc-h._tcp.local.": [ "_xbmc-jsonrpc-h._tcp.local.": [
{ {
"domain": "kodi", "domain": "kodi",

View File

@ -2750,7 +2750,7 @@ wled==0.17.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.2.0 wyoming==1.3.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -2054,7 +2054,7 @@ wled==0.17.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.2.0 wyoming==1.3.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -1,11 +1,13 @@
"""Tests for the Wyoming integration.""" """Tests for the Wyoming integration."""
import asyncio import asyncio
from wyoming.event import Event
from wyoming.info import ( from wyoming.info import (
AsrModel, AsrModel,
AsrProgram, AsrProgram,
Attribution, Attribution,
Info, Info,
Satellite,
TtsProgram, TtsProgram,
TtsVoice, TtsVoice,
TtsVoiceSpeaker, TtsVoiceSpeaker,
@ -72,24 +74,36 @@ WAKE_WORD_INFO = Info(
) )
] ]
) )
SATELLITE_INFO = Info(
satellite=Satellite(
name="Test Satellite",
description="Test Satellite",
installed=True,
attribution=TEST_ATTR,
area="Office",
)
)
EMPTY_INFO = Info() EMPTY_INFO = Info()
class MockAsyncTcpClient: class MockAsyncTcpClient:
"""Mock AsyncTcpClient.""" """Mock AsyncTcpClient."""
def __init__(self, responses) -> None: def __init__(self, responses: list[Event]) -> None:
"""Initialize.""" """Initialize."""
self.host = None self.host: str | None = None
self.port = None self.port: int | None = None
self.written = [] self.written: list[Event] = []
self.responses = responses self.responses = responses
async def write_event(self, event): async def connect(self) -> None:
"""Connect."""
async def write_event(self, event: Event):
"""Send.""" """Send."""
self.written.append(event) self.written.append(event)
async def read_event(self): async def read_event(self) -> Event | None:
"""Receive.""" """Receive."""
await asyncio.sleep(0) # force context switch await asyncio.sleep(0) # force context switch
@ -105,7 +119,7 @@ class MockAsyncTcpClient:
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
"""Exit.""" """Exit."""
def __call__(self, host, port): def __call__(self, host: str, port: int):
"""Call.""" """Call."""
self.host = host self.host = host
self.port = port self.port = port

View File

@ -5,14 +5,23 @@ from unittest.mock import AsyncMock, patch
import pytest import pytest
from homeassistant.components import stt from homeassistant.components import stt
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import STT_INFO, TTS_INFO, WAKE_WORD_INFO from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@pytest.fixture(autouse=True)
async def init_components(hass: HomeAssistant):
"""Set up required components."""
assert await async_setup_component(hass, "homeassistant", {})
@pytest.fixture @pytest.fixture
def mock_setup_entry() -> Generator[AsyncMock, None, None]: def mock_setup_entry() -> Generator[AsyncMock, None, None]:
"""Override async_setup_entry.""" """Override async_setup_entry."""
@ -110,3 +119,39 @@ def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO, channel=stt.AudioChannels.CHANNEL_MONO,
) )
@pytest.fixture
def satellite_config_entry(hass: HomeAssistant) -> ConfigEntry:
"""Create a config entry."""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Satellite",
)
entry.add_to_hass(hass)
return entry
@pytest.fixture
async def init_satellite(hass: HomeAssistant, satellite_config_entry: ConfigEntry):
"""Initialize Wyoming satellite."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.run"
) as _run_mock:
# _run_mock: satellite task does not actually run
await hass.config_entries.async_setup(satellite_config_entry.entry_id)
@pytest.fixture
async def satellite_device(
hass: HomeAssistant, init_satellite, satellite_config_entry: ConfigEntry
) -> SatelliteDevice:
"""Get a satellite device fixture."""
return hass.data[DOMAIN][satellite_config_entry.entry_id].satellite.device

View File

@ -121,3 +121,45 @@
'version': 1, 'version': 1,
}) })
# --- # ---
# name: test_zeroconf_discovery
FlowResultSnapshot({
'context': dict({
'name': 'Test Satellite',
'source': 'zeroconf',
'title_placeholders': dict({
'name': 'Test Satellite',
}),
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
}),
'data': dict({
'host': '127.0.0.1',
'port': 12345,
}),
'description': None,
'description_placeholders': None,
'flow_id': <ANY>,
'handler': 'wyoming',
'options': dict({
}),
'result': ConfigEntrySnapshot({
'data': dict({
'host': '127.0.0.1',
'port': 12345,
}),
'disabled_by': None,
'domain': 'wyoming',
'entry_id': <ANY>,
'options': dict({
}),
'pref_disable_new_entities': False,
'pref_disable_polling': False,
'source': 'zeroconf',
'title': 'Test Satellite',
'unique_id': 'test_zeroconf_name._wyoming._tcp.local._Test Satellite',
'version': 1,
}),
'title': 'Test Satellite',
'type': <FlowResultType.CREATE_ENTRY: 'create_entry'>,
'version': 1,
})
# ---

View File

@ -0,0 +1,34 @@
"""Test Wyoming binary sensor devices."""
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
async def test_assist_in_progress(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test assist in progress."""
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
assert assist_in_progress_id
state = hass.states.get(assist_in_progress_id)
assert state is not None
assert state.state == STATE_OFF
assert not satellite_device.is_active
satellite_device.set_is_active(True)
state = hass.states.get(assist_in_progress_id)
assert state is not None
assert state.state == STATE_ON
assert satellite_device.is_active
satellite_device.set_is_active(False)
state = hass.states.get(assist_in_progress_id)
assert state is not None
assert state.state == STATE_OFF
assert not satellite_device.is_active

View File

@ -1,4 +1,5 @@
"""Test the Wyoming config flow.""" """Test the Wyoming config flow."""
from ipaddress import IPv4Address
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
@ -8,10 +9,11 @@ from wyoming.info import Info
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.hassio import HassioServiceInfo from homeassistant.components.hassio import HassioServiceInfo
from homeassistant.components.wyoming.const import DOMAIN from homeassistant.components.wyoming.const import DOMAIN
from homeassistant.components.zeroconf import ZeroconfServiceInfo
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from . import EMPTY_INFO, STT_INFO, TTS_INFO from . import EMPTY_INFO, SATELLITE_INFO, STT_INFO, TTS_INFO
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -25,6 +27,16 @@ ADDON_DISCOVERY = HassioServiceInfo(
uuid="1234", uuid="1234",
) )
ZEROCONF_DISCOVERY = ZeroconfServiceInfo(
ip_address=IPv4Address("127.0.0.1"),
ip_addresses=[IPv4Address("127.0.0.1")],
port=12345,
hostname="localhost",
type="_wyoming._tcp.local.",
name="test_zeroconf_name._wyoming._tcp.local.",
properties={},
)
pytestmark = pytest.mark.usefixtures("mock_setup_entry") pytestmark = pytest.mark.usefixtures("mock_setup_entry")
@ -214,3 +226,70 @@ async def test_hassio_addon_no_supported_services(hass: HomeAssistant) -> None:
assert result2.get("type") == FlowResultType.ABORT assert result2.get("type") == FlowResultType.ABORT
assert result2.get("reason") == "no_services" assert result2.get("reason") == "no_services"
async def test_zeroconf_discovery(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
snapshot: SnapshotAssertion,
) -> None:
"""Test config flow initiated by Supervisor."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
):
result = await hass.config_entries.flow.async_init(
DOMAIN,
data=ZEROCONF_DISCOVERY,
context={"source": config_entries.SOURCE_ZEROCONF},
)
assert result.get("type") == FlowResultType.FORM
assert result.get("step_id") == "zeroconf_confirm"
assert result.get("description_placeholders") == {
"name": SATELLITE_INFO.satellite.name
}
result2 = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result2.get("type") == FlowResultType.CREATE_ENTRY
assert result2 == snapshot
async def test_zeroconf_discovery_no_port(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
snapshot: SnapshotAssertion,
) -> None:
"""Test discovery when the zeroconf service does not have a port."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch.object(ZEROCONF_DISCOVERY, "port", None):
result = await hass.config_entries.flow.async_init(
DOMAIN,
data=ZEROCONF_DISCOVERY,
context={"source": config_entries.SOURCE_ZEROCONF},
)
assert result.get("type") == FlowResultType.ABORT
assert result.get("reason") == "no_port"
async def test_zeroconf_discovery_no_services(
hass: HomeAssistant,
mock_setup_entry: AsyncMock,
snapshot: SnapshotAssertion,
) -> None:
"""Test discovery when there are no supported services on the client."""
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=Info(),
):
result = await hass.config_entries.flow.async_init(
DOMAIN,
data=ZEROCONF_DISCOVERY,
context={"source": config_entries.SOURCE_ZEROCONF},
)
assert result.get("type") == FlowResultType.ABORT
assert result.get("reason") == "no_services"

View File

@ -3,13 +3,15 @@ from __future__ import annotations
from unittest.mock import patch from unittest.mock import patch
from homeassistant.components.wyoming.data import load_wyoming_info from syrupy.assertion import SnapshotAssertion
from homeassistant.components.wyoming.data import WyomingService, load_wyoming_info
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from . import STT_INFO, MockAsyncTcpClient from . import SATELLITE_INFO, STT_INFO, TTS_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
async def test_load_info(hass: HomeAssistant, snapshot) -> None: async def test_load_info(hass: HomeAssistant, snapshot: SnapshotAssertion) -> None:
"""Test loading info.""" """Test loading info."""
with patch( with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient", "homeassistant.components.wyoming.data.AsyncTcpClient",
@ -38,3 +40,38 @@ async def test_load_info_oserror(hass: HomeAssistant) -> None:
) )
assert info is None assert info is None
async def test_service_name(hass: HomeAssistant) -> None:
"""Test loading service info."""
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([STT_INFO.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == STT_INFO.asr[0].name
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([TTS_INFO.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == TTS_INFO.tts[0].name
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([WAKE_WORD_INFO.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == WAKE_WORD_INFO.wake[0].name
with patch(
"homeassistant.components.wyoming.data.AsyncTcpClient",
MockAsyncTcpClient([SATELLITE_INFO.event()]),
):
service = await WyomingService.create("localhost", 1234)
assert service is not None
assert service.get_name() == SATELLITE_INFO.satellite.name

View File

@ -0,0 +1,78 @@
"""Test Wyoming devices."""
from __future__ import annotations
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.wyoming import DOMAIN
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr
async def test_device_registry_info(
hass: HomeAssistant,
satellite_device: SatelliteDevice,
satellite_config_entry: ConfigEntry,
device_registry: dr.DeviceRegistry,
) -> None:
"""Test info in device registry."""
# Satellite uses config entry id since only one satellite per entry is
# supported.
device = device_registry.async_get_device(
identifiers={(DOMAIN, satellite_config_entry.entry_id)}
)
assert device is not None
assert device.name == "Test Satellite"
assert device.suggested_area == "Office"
# Check associated entities
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
assert assist_in_progress_id
assist_in_progress_state = hass.states.get(assist_in_progress_id)
assert assist_in_progress_state is not None
assert assist_in_progress_state.state == STATE_OFF
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
assert satellite_enabled_id
satellite_enabled_state = hass.states.get(satellite_enabled_id)
assert satellite_enabled_state is not None
assert satellite_enabled_state.state == STATE_ON
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
assert pipeline_entity_id
pipeline_state = hass.states.get(pipeline_entity_id)
assert pipeline_state is not None
assert pipeline_state.state == OPTION_PREFERRED
async def test_remove_device_registry_entry(
hass: HomeAssistant,
satellite_device: SatelliteDevice,
device_registry: dr.DeviceRegistry,
) -> None:
"""Test removing a device registry entry."""
# Check associated entities
assist_in_progress_id = satellite_device.get_assist_in_progress_entity_id(hass)
assert assist_in_progress_id
assert hass.states.get(assist_in_progress_id) is not None
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
assert satellite_enabled_id
assert hass.states.get(satellite_enabled_id) is not None
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
assert pipeline_entity_id
assert hass.states.get(pipeline_entity_id) is not None
# Remove
device_registry.async_remove_device(satellite_device.device_id)
await hass.async_block_till_done()
await hass.async_block_till_done()
# Everything should be gone
assert hass.states.get(assist_in_progress_id) is None
assert hass.states.get(satellite_enabled_id) is None
assert hass.states.get(pipeline_entity_id) is None

View File

@ -0,0 +1,460 @@
"""Test Wyoming satellite."""
from __future__ import annotations
import asyncio
import io
from unittest.mock import patch
import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.event import Event
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize
from wyoming.vad import VoiceStarted, VoiceStopped
from wyoming.wake import Detect, Detection
from homeassistant.components import assist_pipeline, wyoming
from homeassistant.components.wyoming.data import WyomingService
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from . import SATELLITE_INFO, MockAsyncTcpClient
from tests.common import MockConfigEntry
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
"""Set up config entry for Wyoming satellite.
This is separated from the satellite_config_entry method in conftest.py so
we can patch functions before the satellite task is run during setup.
"""
entry = MockConfigEntry(
domain="wyoming",
data={
"host": "1.2.3.4",
"port": 1234,
},
title="Test Satellite",
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
return entry
def get_test_wav() -> bytes:
"""Get bytes for test WAV file."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
# Single frame
wav_file.writeframes(b"123")
return wav_io.getvalue()
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
"""Satellite AsyncTcpClient."""
def __init__(self, responses: list[Event]) -> None:
"""Initialize client."""
super().__init__(responses)
self.connect_event = asyncio.Event()
self.run_satellite_event = asyncio.Event()
self.detect_event = asyncio.Event()
self.detection_event = asyncio.Event()
self.detection: Detection | None = None
self.transcribe_event = asyncio.Event()
self.transcribe: Transcribe | None = None
self.voice_started_event = asyncio.Event()
self.voice_started: VoiceStarted | None = None
self.voice_stopped_event = asyncio.Event()
self.voice_stopped: VoiceStopped | None = None
self.transcript_event = asyncio.Event()
self.transcript: Transcript | None = None
self.synthesize_event = asyncio.Event()
self.synthesize: Synthesize | None = None
self.tts_audio_start_event = asyncio.Event()
self.tts_audio_chunk_event = asyncio.Event()
self.tts_audio_stop_event = asyncio.Event()
self.tts_audio_chunk: AudioChunk | None = None
self._mic_audio_chunk = AudioChunk(
rate=16000, width=2, channels=1, audio=b"chunk"
).event()
async def connect(self) -> None:
"""Connect."""
self.connect_event.set()
async def write_event(self, event: Event):
"""Send."""
if RunSatellite.is_type(event.type):
self.run_satellite_event.set()
elif Detect.is_type(event.type):
self.detect_event.set()
elif Detection.is_type(event.type):
self.detection = Detection.from_event(event)
self.detection_event.set()
elif Transcribe.is_type(event.type):
self.transcribe = Transcribe.from_event(event)
self.transcribe_event.set()
elif VoiceStarted.is_type(event.type):
self.voice_started = VoiceStarted.from_event(event)
self.voice_started_event.set()
elif VoiceStopped.is_type(event.type):
self.voice_stopped = VoiceStopped.from_event(event)
self.voice_stopped_event.set()
elif Transcript.is_type(event.type):
self.transcript = Transcript.from_event(event)
self.transcript_event.set()
elif Synthesize.is_type(event.type):
self.synthesize = Synthesize.from_event(event)
self.synthesize_event.set()
elif AudioStart.is_type(event.type):
self.tts_audio_start_event.set()
elif AudioChunk.is_type(event.type):
self.tts_audio_chunk = AudioChunk.from_event(event)
self.tts_audio_chunk_event.set()
elif AudioStop.is_type(event.type):
self.tts_audio_stop_event.set()
async def read_event(self) -> Event | None:
"""Receive."""
event = await super().read_event()
# Keep sending audio chunks instead of None
return event or self._mic_audio_chunk
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
"""Test running a pipeline with a satellite."""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
]
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch(
"homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
return_value=("wav", get_test_wav()),
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
# Start detecting wake word
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.WAKE_WORD_START
)
)
async with asyncio.timeout(1):
await mock_client.detect_event.wait()
assert not device.is_active
assert device.is_enabled
# Wake word is detected
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.WAKE_WORD_END,
{"wake_word_output": {"wake_word_id": "test_wake_word"}},
)
)
async with asyncio.timeout(1):
await mock_client.detection_event.wait()
assert mock_client.detection is not None
assert mock_client.detection.name == "test_wake_word"
# "Assist in progress" sensor should be active now
assert device.is_active
# Speech-to-text started
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_START,
{"metadata": {"language": "en"}},
)
)
async with asyncio.timeout(1):
await mock_client.transcribe_event.wait()
assert mock_client.transcribe is not None
assert mock_client.transcribe.language == "en"
# User started speaking
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
)
)
async with asyncio.timeout(1):
await mock_client.voice_started_event.wait()
assert mock_client.voice_started is not None
assert mock_client.voice_started.timestamp == 1234
# User stopped speaking
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
)
)
async with asyncio.timeout(1):
await mock_client.voice_stopped_event.wait()
assert mock_client.voice_stopped is not None
assert mock_client.voice_stopped.timestamp == 5678
# Speech-to-text transcription
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.STT_END,
{"stt_output": {"text": "test transcript"}},
)
)
async with asyncio.timeout(1):
await mock_client.transcript_event.wait()
assert mock_client.transcript is not None
assert mock_client.transcript.text == "test transcript"
# Text-to-speech text
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_START,
{
"tts_input": "test text to speak",
"voice": "test voice",
},
)
)
async with asyncio.timeout(1):
await mock_client.synthesize_event.wait()
assert mock_client.synthesize is not None
assert mock_client.synthesize.text == "test text to speak"
assert mock_client.synthesize.voice is not None
assert mock_client.synthesize.voice.name == "test voice"
# Text-to-speech media
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.TTS_END,
{"tts_output": {"media_id": "test media id"}},
)
)
async with asyncio.timeout(1):
await mock_client.tts_audio_start_event.wait()
await mock_client.tts_audio_chunk_event.wait()
await mock_client.tts_audio_stop_event.wait()
# Verify audio chunk from test WAV
assert mock_client.tts_audio_chunk is not None
assert mock_client.tts_audio_chunk.rate == 22050
assert mock_client.tts_audio_chunk.width == 2
assert mock_client.tts_audio_chunk.channels == 1
assert mock_client.tts_audio_chunk.audio == b"123"
# Pipeline finished
event_callback(
assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
)
assert not device.is_active
# Stop the satellite
await hass.config_entries.async_unload(entry.entry_id)
await hass.async_block_till_done()
async def test_satellite_disabled(hass: HomeAssistant) -> None:
"""Test callback for a satellite that has been disabled."""
on_disabled_event = asyncio.Event()
original_make_satellite = wyoming._make_satellite
def make_disabled_satellite(
hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
):
satellite = original_make_satellite(hass, config_entry, service)
satellite.device.is_enabled = False
return satellite
async def on_disabled(self):
on_disabled_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming._make_satellite", make_disabled_satellite
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_disabled",
on_disabled,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_disabled_event.wait()
async def test_satellite_restart(hass: HomeAssistant) -> None:
"""Test pipeline loop restart after unexpected error."""
on_restart_event = asyncio.Event()
async def on_restart(self):
self.stop()
on_restart_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
side_effect=RuntimeError(),
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_restart_event.wait()
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
"""Test satellite reconnect call after connection refused."""
on_reconnect_event = asyncio.Event()
async def on_reconnect(self):
self.stop()
on_reconnect_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
side_effect=ConnectionRefusedError(),
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
on_reconnect,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_reconnect_event.wait()
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
"""Test satellite disconnecting before pipeline run."""
on_restart_event = asyncio.Event()
async def on_restart(self):
self.stop()
on_restart_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
MockAsyncTcpClient([]), # no RunPipeline event
), patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart,
):
await setup_config_entry(hass)
async with asyncio.timeout(1):
await on_restart_event.wait()
# Pipeline should never have run
mock_run_pipeline.assert_not_called()
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
"""Test satellite disconnecting during pipeline run."""
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
] # no audio chunks after RunPipeline
on_restart_event = asyncio.Event()
on_stopped_event = asyncio.Event()
async def on_restart(self):
# Pretend sensor got stuck on
self.device.is_active = True
self.stop()
on_restart_event.set()
async def on_stopped(self):
on_stopped_event.set()
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
MockAsyncTcpClient(events),
), patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline, patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
on_restart,
), patch(
"homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
on_stopped,
):
entry = await setup_config_entry(hass)
device: SatelliteDevice = hass.data[wyoming.DOMAIN][
entry.entry_id
].satellite.device
async with asyncio.timeout(1):
await on_restart_event.wait()
await on_stopped_event.wait()
# Pipeline should have run once
mock_run_pipeline.assert_called_once()
# Sensor should have been turned off
assert not device.is_active

View File

@ -0,0 +1,83 @@
"""Test Wyoming select."""
from unittest.mock import Mock, patch
from homeassistant.components import assist_pipeline
from homeassistant.components.assist_pipeline.pipeline import PipelineData
from homeassistant.components.assist_pipeline.select import OPTION_PREFERRED
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
async def test_pipeline_select(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test pipeline select.
Functionality is tested in assist_pipeline/test_select.py.
This test is only to ensure it is set up.
"""
assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})
pipeline_data: PipelineData = hass.data[assist_pipeline.DOMAIN]
# Create second pipeline
await pipeline_data.pipeline_store.async_create_item(
{
"name": "Test 1",
"language": "en-US",
"conversation_engine": None,
"conversation_language": "en-US",
"tts_engine": None,
"tts_language": None,
"tts_voice": None,
"stt_engine": None,
"stt_language": None,
"wake_word_entity": None,
"wake_word_id": None,
}
)
# Preferred pipeline is the default
pipeline_entity_id = satellite_device.get_pipeline_entity_id(hass)
assert pipeline_entity_id
state = hass.states.get(pipeline_entity_id)
assert state is not None
assert state.state == OPTION_PREFERRED
# Change to second pipeline
with patch.object(satellite_device, "set_pipeline_name") as mock_pipeline_changed:
await hass.services.async_call(
"select",
"select_option",
{"entity_id": pipeline_entity_id, "option": "Test 1"},
blocking=True,
)
state = hass.states.get(pipeline_entity_id)
assert state is not None
assert state.state == "Test 1"
# async_pipeline_changed should have been called
mock_pipeline_changed.assert_called_once_with("Test 1")
# Change back and check update listener
pipeline_listener = Mock()
satellite_device.set_pipeline_listener(pipeline_listener)
await hass.services.async_call(
"select",
"select_option",
{"entity_id": pipeline_entity_id, "option": OPTION_PREFERRED},
blocking=True,
)
state = hass.states.get(pipeline_entity_id)
assert state is not None
assert state.state == OPTION_PREFERRED
# listener should have been called
pipeline_listener.assert_called_once()

View File

@ -0,0 +1,32 @@
"""Test Wyoming switch devices."""
from homeassistant.components.wyoming.devices import SatelliteDevice
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
async def test_satellite_enabled(
hass: HomeAssistant,
satellite_config_entry: ConfigEntry,
satellite_device: SatelliteDevice,
) -> None:
"""Test satellite enabled."""
satellite_enabled_id = satellite_device.get_satellite_enabled_entity_id(hass)
assert satellite_enabled_id
state = hass.states.get(satellite_enabled_id)
assert state is not None
assert state.state == STATE_ON
assert satellite_device.is_enabled
await hass.services.async_call(
"switch",
"turn_off",
{"entity_id": satellite_enabled_id},
blocking=True,
)
state = hass.states.get(satellite_enabled_id)
assert state is not None
assert state.state == STATE_OFF
assert not satellite_device.is_enabled