Add STT error code for cloud authentication failure (#133170)

pull/132975/head
Michael Hansen 2024-12-13 13:59:46 -06:00 committed by GitHub
parent e13fa8346a
commit 50b897bdaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 83 additions and 0 deletions

View File

@ -16,6 +16,7 @@ import time
from typing import Any, Literal, cast
import wave
import hass_nabucasa
import voluptuous as vol
from homeassistant.components import (
@ -918,6 +919,11 @@ class PipelineRun:
)
except (asyncio.CancelledError, TimeoutError):
raise # expected
except hass_nabucasa.auth.Unauthenticated as src_error:
raise SpeechToTextError(
code="cloud-auth-failed",
message="Home Assistant Cloud authentication failed",
) from src_error
except Exception as src_error:
_LOGGER.exception("Unexpected error during speech-to-text")
raise SpeechToTextError(

View File

@ -387,6 +387,42 @@
}),
])
# ---
# name: test_pipeline_from_audio_stream_with_cloud_auth_fail
list([
dict({
'data': dict({
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'engine': 'stt.mock_stt',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'language': 'en-US',
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
}),
'type': <PipelineEventType.STT_START: 'stt-start'>,
}),
dict({
'data': dict({
'code': 'cloud-auth-failed',
'message': 'Home Assistant Cloud authentication failed',
}),
'type': <PipelineEventType.ERROR: 'error'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_pipeline_language_used_instead_of_conversation_language
list([
dict({

View File

@ -8,6 +8,7 @@ import tempfile
from unittest.mock import ANY, patch
import wave
import hass_nabucasa
import pytest
from syrupy.assertion import SnapshotAssertion
@ -1173,3 +1174,43 @@ async def test_pipeline_language_used_instead_of_conversation_language(
mock_async_converse.call_args_list[0].kwargs.get("language")
== pipeline.language
)
async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,
init_components,
snapshot: SnapshotAssertion,
) -> None:
"""Test creating a pipeline from an audio stream but the cloud authentication fails."""
events: list[assist_pipeline.PipelineEvent] = []
async def audio_data():
yield b"audio"
with patch.object(
mock_stt_provider_entity,
"async_process_audio_stream",
side_effect=hass_nabucasa.auth.Unauthenticated,
):
await assist_pipeline.async_pipeline_from_audio_stream(
hass,
context=Context(),
event_callback=events.append,
stt_metadata=stt.SpeechMetadata(
language="",
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot
assert len(events) == 4 # run start, stt start, error, run end
assert events[2].type == assist_pipeline.PipelineEventType.ERROR
assert events[2].data["code"] == "cloud-auth-failed"