Set device id and forward errors to Wyoming satellites (#105266)

* Set device id and forward errors

* Fix tests
pull/105274/head
Michael Hansen 2023-12-07 19:44:43 -06:00 committed by GitHub
parent e9f8e7ab50
commit 43daeb2630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 5 deletions

View File

@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push",
"requirements": ["wyoming==1.3.0"],
"requirements": ["wyoming==1.4.0"],
"zeroconf": ["_wyoming._tcp.local."]
}

View File

@ -9,6 +9,7 @@ import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize, SynthesizeVoice
@ -239,6 +240,7 @@ class WyomingSatellite:
auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier,
),
device_id=self.device.device_id,
)
)
@ -333,6 +335,16 @@ class WyomingSatellite:
if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
async def _connect(self) -> None:
"""Connect to satellite over TCP."""

View File

@ -2760,7 +2760,7 @@ wled==0.17.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.3.0
wyoming==1.4.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View File

@ -2064,7 +2064,7 @@ wled==0.17.0
wolf-smartset==0.1.11
# homeassistant.components.wyoming
wyoming==1.3.0
wyoming==1.4.0
# homeassistant.components.xbox
xbox-webapi==2.0.11

View File

@ -6,7 +6,7 @@
'language': 'en',
}),
'payload': None,
'type': 'transcibe',
'type': 'transcribe',
}),
dict({
'data': dict({

View File

@ -8,6 +8,7 @@ import wave
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite
@ -96,6 +97,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.tts_audio_stop_event = asyncio.Event()
self.tts_audio_chunk: AudioChunk | None = None
self.error_event = asyncio.Event()
self.error: Error | None = None
self._mic_audio_chunk = AudioChunk(
rate=16000, width=2, channels=1, audio=b"chunk"
).event()
@ -135,6 +139,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.tts_audio_chunk_event.set()
elif AudioStop.is_type(event.type):
self.tts_audio_stop_event.set()
elif Error.is_type(event.type):
self.error = Error.from_event(event)
self.error_event.set()
async def read_event(self) -> Event | None:
"""Receive."""
@ -175,8 +182,9 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called()
mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id
# Start detecting wake word
event_callback(
@ -458,3 +466,43 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
# Sensor should have been turned off
assert not device.is_active
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
"""Test satellite error occurring during pipeline run."""
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
] # no audio chunks after RunPipeline
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline:
await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.ERROR,
{"code": "test code", "message": "test message"},
)
)
async with asyncio.timeout(1):
await mock_client.error_event.wait()
assert mock_client.error is not None
assert mock_client.error.text == "test message"
assert mock_client.error.code == "test code"