522 lines
18 KiB
Python
522 lines
18 KiB
Python
"""Test Wyoming satellite."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import io
|
|
from unittest.mock import patch
|
|
import wave
|
|
|
|
from wyoming.asr import Transcribe, Transcript
|
|
from wyoming.audio import AudioChunk, AudioStart, AudioStop
|
|
from wyoming.error import Error
|
|
from wyoming.event import Event
|
|
from wyoming.pipeline import PipelineStage, RunPipeline
|
|
from wyoming.satellite import RunSatellite
|
|
from wyoming.tts import Synthesize
|
|
from wyoming.vad import VoiceStarted, VoiceStopped
|
|
from wyoming.wake import Detect, Detection
|
|
|
|
from homeassistant.components import assist_pipeline, wyoming
|
|
from homeassistant.components.wyoming.data import WyomingService
|
|
from homeassistant.components.wyoming.devices import SatelliteDevice
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.setup import async_setup_component
|
|
|
|
from . import SATELLITE_INFO, MockAsyncTcpClient
|
|
|
|
from tests.common import MockConfigEntry
|
|
|
|
|
|
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
    """Create, register and set up a Wyoming config entry for a test.

    Kept apart from the satellite_config_entry fixture in conftest.py so a test
    can install its patches first: the satellite task is started by the setup
    performed here.

    Returns the MockConfigEntry once setup has been awaited and
    hass.async_block_till_done() has returned.
    """
    connection = {"host": "1.2.3.4", "port": 1234}
    config_entry = MockConfigEntry(
        domain="wyoming", data=connection, title="Test Satellite"
    )
    config_entry.add_to_hass(hass)

    await hass.config_entries.async_setup(config_entry.entry_id)
    # Let setup, and whatever it scheduled, settle before the test continues.
    await hass.async_block_till_done()
    return config_entry
|
|
|
|
|
|
def get_test_wav() -> bytes:
    """Return an in-memory WAV file used as fake text-to-speech output.

    The header describes 22050 Hz, 2-byte (16-bit) samples, 1 channel, and the
    raw bytes b"123" are written as the audio data. At 2 bytes per frame that
    is not a whole number of frames; tests compare the bytes verbatim.
    """
    buffer = io.BytesIO()
    # Closing the writer finalizes the RIFF header before the bytes are read.
    with wave.open(buffer, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(2)
        writer.setframerate(22050)
        writer.writeframes(b"123")
    return buffer.getvalue()
|
|
|
|
|
|
class SatelliteAsyncTcpClient(MockAsyncTcpClient):
    """Fake satellite connection that records what Home Assistant sends it.

    Every event type the tests wait on has an asyncio.Event that is set when
    Home Assistant writes that event. Events with a payload also store the
    most recently parsed message on a matching attribute.
    """

    def __init__(self, responses: list[Event]) -> None:
        """Initialize client.

        responses: events handed to MockAsyncTcpClient, presumably replayed by
        read_event (see the base class).
        """
        super().__init__(responses)

        # Signal-only events: nothing besides their arrival is recorded.
        self.connect_event = asyncio.Event()
        self.run_satellite_event = asyncio.Event()
        self.detect_event = asyncio.Event()
        self.tts_audio_start_event = asyncio.Event()
        self.tts_audio_stop_event = asyncio.Event()

        # Events whose latest parsed message is kept next to the signal.
        self.detection_event = asyncio.Event()
        self.detection: Detection | None = None
        self.transcribe_event = asyncio.Event()
        self.transcribe: Transcribe | None = None
        self.voice_started_event = asyncio.Event()
        self.voice_started: VoiceStarted | None = None
        self.voice_stopped_event = asyncio.Event()
        self.voice_stopped: VoiceStopped | None = None
        self.transcript_event = asyncio.Event()
        self.transcript: Transcript | None = None
        self.synthesize_event = asyncio.Event()
        self.synthesize: Synthesize | None = None
        self.tts_audio_chunk_event = asyncio.Event()
        self.tts_audio_chunk: AudioChunk | None = None
        self.error_event = asyncio.Event()
        self.error: Error | None = None

        # Fake microphone audio returned whenever the base read_event yields
        # nothing.
        self._mic_audio_chunk = AudioChunk(
            rate=16000, width=2, channels=1, audio=b"chunk"
        ).event()

    async def connect(self) -> None:
        """Record the connection attempt; no real connection is made."""
        self.connect_event.set()

    async def write_event(self, event: Event):
        """Record an event sent by Home Assistant and signal its arrival.

        The first message class whose is_type() accepts the event wins. If an
        attribute name is listed, the parsed message is stored there before
        the corresponding asyncio.Event is set. Unrecognized events are ignored.
        """
        handlers = (
            (RunSatellite, None, self.run_satellite_event),
            (Detect, None, self.detect_event),
            (Detection, "detection", self.detection_event),
            (Transcribe, "transcribe", self.transcribe_event),
            (VoiceStarted, "voice_started", self.voice_started_event),
            (VoiceStopped, "voice_stopped", self.voice_stopped_event),
            (Transcript, "transcript", self.transcript_event),
            (Synthesize, "synthesize", self.synthesize_event),
            (AudioStart, None, self.tts_audio_start_event),
            (AudioChunk, "tts_audio_chunk", self.tts_audio_chunk_event),
            (AudioStop, None, self.tts_audio_stop_event),
            (Error, "error", self.error_event),
        )
        for message_cls, attr_name, arrived in handlers:
            if not message_cls.is_type(event.type):
                continue
            if attr_name is not None:
                setattr(self, attr_name, message_cls.from_event(event))
            arrived.set()
            return

    async def read_event(self) -> Event | None:
        """Return the next event from the base client, else fake mic audio."""
        next_event = await super().read_event()
        if not next_event:
            # Keep sending audio chunks instead of None.
            return self._mic_audio_chunk
        return next_event
|
|
|
|
|
|
async def test_satellite_pipeline(hass: HomeAssistant) -> None:
    """Test running a pipeline with a satellite.

    The real pipeline runner (async_pipeline_from_audio_stream) is patched
    out. The test plays the pipeline's role by calling the event_callback the
    satellite handed to it, stage by stage (wake word, STT, TTS, run end).
    After each step it waits for the fake client to receive the Wyoming event
    the satellite is expected to send.
    """
    assert await async_setup_component(hass, assist_pipeline.DOMAIN, {})

    # Scripted response from the fake satellite: run the whole pipeline,
    # from wake word detection through text-to-speech.
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    with patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=SATELLITE_INFO,
    ), patch(
        # The instance replaces the AsyncTcpClient class, so the satellite
        # talks to this recording fake.
        "homeassistant.components.wyoming.satellite.AsyncTcpClient",
        SatelliteAsyncTcpClient(events),
    ) as mock_client, patch(
        "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
    ) as mock_run_pipeline, patch(
        # TTS "media" is the small test WAV; its format is asserted below.
        "homeassistant.components.wyoming.satellite.tts.async_get_media_source_audio",
        return_value=("wav", get_test_wav()),
    ):
        entry = await setup_config_entry(hass)
        device: SatelliteDevice = hass.data[wyoming.DOMAIN][
            entry.entry_id
        ].satellite.device

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        # The RunPipeline response must have started exactly one pipeline run,
        # tied to this satellite's device.
        mock_run_pipeline.assert_called_once()
        event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
        assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id

        # Start detecting wake word
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_START
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detect_event.wait()

        assert not device.is_active
        assert not device.is_muted

        # Wake word is detected
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.WAKE_WORD_END,
                {"wake_word_output": {"wake_word_id": "test_wake_word"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.detection_event.wait()

        assert mock_client.detection is not None
        assert mock_client.detection.name == "test_wake_word"

        # "Assist in progress" sensor should be active now
        assert device.is_active

        # Speech-to-text started
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_START,
                {"metadata": {"language": "en"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcribe_event.wait()

        assert mock_client.transcribe is not None
        assert mock_client.transcribe.language == "en"

        # User started speaking
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_START, {"timestamp": 1234}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_started_event.wait()

        assert mock_client.voice_started is not None
        assert mock_client.voice_started.timestamp == 1234

        # User stopped speaking
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_VAD_END, {"timestamp": 5678}
            )
        )
        async with asyncio.timeout(1):
            await mock_client.voice_stopped_event.wait()

        assert mock_client.voice_stopped is not None
        assert mock_client.voice_stopped.timestamp == 5678

        # Speech-to-text transcription
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.STT_END,
                {"stt_output": {"text": "test transcript"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.transcript_event.wait()

        assert mock_client.transcript is not None
        assert mock_client.transcript.text == "test transcript"

        # Text-to-speech text
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_START,
                {
                    "tts_input": "test text to speak",
                    "voice": "test voice",
                },
            )
        )
        async with asyncio.timeout(1):
            await mock_client.synthesize_event.wait()

        assert mock_client.synthesize is not None
        assert mock_client.synthesize.text == "test text to speak"
        assert mock_client.synthesize.voice is not None
        assert mock_client.synthesize.voice.name == "test voice"

        # Text-to-speech media: the satellite should stream the patched WAV
        # back as AudioStart / AudioChunk / AudioStop.
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.TTS_END,
                {"tts_output": {"media_id": "test media id"}},
            )
        )
        async with asyncio.timeout(1):
            await mock_client.tts_audio_start_event.wait()
            await mock_client.tts_audio_chunk_event.wait()
            await mock_client.tts_audio_stop_event.wait()

        # Verify audio chunk from test WAV
        assert mock_client.tts_audio_chunk is not None
        assert mock_client.tts_audio_chunk.rate == 22050
        assert mock_client.tts_audio_chunk.width == 2
        assert mock_client.tts_audio_chunk.channels == 1
        assert mock_client.tts_audio_chunk.audio == b"123"

        # Pipeline finished
        event_callback(
            assist_pipeline.PipelineEvent(assist_pipeline.PipelineEventType.RUN_END)
        )
        assert not device.is_active

        # Stop the satellite (still inside the patch context)
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
|
|
|
|
|
|
async def test_satellite_muted(hass: HomeAssistant) -> None:
    """Test callback for a satellite that has been muted.

    The satellite is created already muted. The test passes once the
    satellite calls on_muted.
    """
    on_muted_event = asyncio.Event()

    original_make_satellite = wyoming._make_satellite

    def make_muted_satellite(
        hass: HomeAssistant, config_entry: ConfigEntry, service: WyomingService
    ):
        # Build the real satellite, but put its device into the muted state.
        satellite = original_make_satellite(hass, config_entry, service)
        satellite.device.set_is_muted(True)

        return satellite

    async def on_muted(self):
        self.device.set_is_muted(False)
        # Stop the satellite, as the patched callbacks in the other tests do.
        # AsyncTcpClient is not patched in this test, so a satellite left
        # running would keep going with the real client after the patches exit.
        self.stop()
        on_muted_event.set()

    with patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=SATELLITE_INFO,
    ), patch(
        "homeassistant.components.wyoming._make_satellite", make_muted_satellite
    ), patch(
        "homeassistant.components.wyoming.satellite.WyomingSatellite.on_muted",
        on_muted,
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await on_muted_event.wait()
|
|
|
|
|
|
async def test_satellite_restart(hass: HomeAssistant) -> None:
    """Test pipeline loop restart after unexpected error.

    _run_once is patched to always raise RuntimeError. The patched on_restart
    stops the satellite and then signals the test.
    """
    restarted = asyncio.Event()

    async def fake_on_restart(self):
        self.stop()
        restarted.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite._run_once",
            side_effect=RuntimeError(),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            fake_on_restart,
        ),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await restarted.wait()
|
|
|
|
|
|
async def test_satellite_reconnect(hass: HomeAssistant) -> None:
    """Test satellite reconnect call after connection refused.

    AsyncTcpClient.connect always raises ConnectionRefusedError. On the second
    on_reconnect call the satellite is stopped. The test then waits for that
    signal and for on_stopped.
    """
    reconnect_count = 0
    reconnected_twice = asyncio.Event()
    stopped = asyncio.Event()

    async def count_reconnect(self):
        nonlocal reconnect_count
        reconnect_count += 1
        if reconnect_count < 2:
            return
        reconnected_twice.set()
        self.stop()

    async def signal_stopped(self):
        stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient.connect",
            side_effect=ConnectionRefusedError(),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_reconnect",
            count_reconnect,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
            signal_stopped,
        ),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await reconnected_twice.wait()
            await stopped.wait()
|
|
|
|
|
|
async def test_satellite_disconnect_before_pipeline(hass: HomeAssistant) -> None:
    """Test satellite disconnecting before pipeline run.

    The fake client has no scripted events, so no RunPipeline ever arrives.
    The satellite should end up in on_restart without starting a pipeline.
    """
    restarted = asyncio.Event()

    async def fake_on_restart(self):
        self.stop()
        restarted.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            MockAsyncTcpClient([]),  # no RunPipeline event
        ),
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
        ) as pipeline_mock,
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            fake_on_restart,
        ),
    ):
        await setup_config_entry(hass)
        async with asyncio.timeout(1):
            await restarted.wait()

        # The pipeline must never have been started.
        pipeline_mock.assert_not_called()
|
|
|
|
|
|
async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None:
    """Test satellite disconnecting during pipeline run.

    The fake client sends RunPipeline and then nothing else. The patched
    on_restart forces the "assist in progress" flag on before stopping. The
    test checks that the flag has been cleared by the time on_stopped fires.
    """
    # RunPipeline only; no audio chunks follow it.
    scripted = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]

    restarted = asyncio.Event()
    stopped = asyncio.Event()

    async def fake_on_restart(self):
        # Simulate the sensor being stuck on.
        self.device.is_active = True
        self.stop()
        restarted.set()

    async def signal_stopped(self):
        stopped.set()

    with (
        patch(
            "homeassistant.components.wyoming.data.load_wyoming_info",
            return_value=SATELLITE_INFO,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.AsyncTcpClient",
            MockAsyncTcpClient(scripted),
        ),
        patch(
            "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
        ) as pipeline_mock,
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_restart",
            fake_on_restart,
        ),
        patch(
            "homeassistant.components.wyoming.satellite.WyomingSatellite.on_stopped",
            signal_stopped,
        ),
    ):
        entry = await setup_config_entry(hass)
        satellite = hass.data[wyoming.DOMAIN][entry.entry_id].satellite
        device: SatelliteDevice = satellite.device

        async with asyncio.timeout(1):
            await restarted.wait()
            await stopped.wait()

        # Exactly one pipeline run was started before the disconnect.
        pipeline_mock.assert_called_once()

        # The stuck sensor must have been switched back off.
        assert not device.is_active
|
|
|
|
|
|
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
    """Test satellite error occurring during pipeline run.

    A pipeline ERROR event, delivered through the captured event_callback,
    must be forwarded to the satellite as a Wyoming Error with the same code
    and message.
    """
    events = [
        RunPipeline(
            start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
        ).event(),
    ]  # no audio chunks after RunPipeline

    with patch(
        "homeassistant.components.wyoming.data.load_wyoming_info",
        return_value=SATELLITE_INFO,
    ), patch(
        "homeassistant.components.wyoming.satellite.AsyncTcpClient",
        SatelliteAsyncTcpClient(events),
    ) as mock_client, patch(
        "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
    ) as mock_run_pipeline:
        entry = await setup_config_entry(hass)

        async with asyncio.timeout(1):
            await mock_client.connect_event.wait()
            await mock_client.run_satellite_event.wait()

        mock_run_pipeline.assert_called_once()
        event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
        event_callback(
            assist_pipeline.PipelineEvent(
                assist_pipeline.PipelineEventType.ERROR,
                {"code": "test code", "message": "test message"},
            )
        )

        async with asyncio.timeout(1):
            await mock_client.error_event.wait()

        assert mock_client.error is not None
        assert mock_client.error.text == "test message"
        assert mock_client.error.code == "test code"

        # Stop the satellite while the patches are still active, as
        # test_satellite_pipeline does. The fake client keeps producing mic
        # audio, so otherwise the satellite would outlive the patch context.
        await hass.config_entries.async_unload(entry.entry_id)
        await hass.async_block_till_done()
|