Migrate Assist Pipeline to use TTS stream (#139542)
* Migrate Pipeline to use TTS stream
* Fix tests
pull/139598/head
parent
c168695323
commit
2cce1b024e
|
@ -19,14 +19,7 @@ import wave
|
|||
import hass_nabucasa
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import (
|
||||
conversation,
|
||||
media_source,
|
||||
stt,
|
||||
tts,
|
||||
wake_word,
|
||||
websocket_api,
|
||||
)
|
||||
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
|
||||
from homeassistant.components.tts import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
|
@ -569,8 +562,7 @@ class PipelineRun:
|
|||
|
||||
id: str = field(default_factory=ulid_util.ulid_now)
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
|
||||
tts_engine: str = field(init=False, repr=False)
|
||||
tts_options: dict | None = field(init=False, default=None)
|
||||
tts_stream: tts.ResultStream | None = field(init=False, default=None)
|
||||
wake_word_entity_id: str | None = field(init=False, default=None, repr=False)
|
||||
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
|
||||
|
||||
|
@ -648,13 +640,18 @@ class PipelineRun:
|
|||
self._device_id = device_id
|
||||
self._start_debug_recording_thread()
|
||||
|
||||
data = {
|
||||
data: dict[str, Any] = {
|
||||
"pipeline": self.pipeline.id,
|
||||
"language": self.language,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
if self.runner_data is not None:
|
||||
data["runner_data"] = self.runner_data
|
||||
if self.tts_stream:
|
||||
data["tts_output"] = {
|
||||
"url": self.tts_stream.url,
|
||||
"mime_type": self.tts_stream.content_type,
|
||||
}
|
||||
|
||||
self.process_event(PipelineEvent(PipelineEventType.RUN_START, data))
|
||||
|
||||
|
@ -1246,36 +1243,31 @@ class PipelineRun:
|
|||
tts_options[tts.ATTR_PREFERRED_SAMPLE_BYTES] = SAMPLE_WIDTH
|
||||
|
||||
try:
|
||||
options_supported = await tts.async_support_options(
|
||||
self.hass,
|
||||
engine,
|
||||
self.pipeline.tts_language,
|
||||
tts_options,
|
||||
self.tts_stream = tts.async_create_stream(
|
||||
hass=self.hass,
|
||||
engine=engine,
|
||||
language=self.pipeline.tts_language,
|
||||
options=tts_options,
|
||||
)
|
||||
except HomeAssistantError as err:
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=f"Text-to-speech engine '{engine}' not found",
|
||||
) from err
|
||||
if not options_supported:
|
||||
raise TextToSpeechError(
|
||||
code="tts-not-supported",
|
||||
message=(
|
||||
f"Text-to-speech engine {engine} "
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}"
|
||||
f"does not support language {self.pipeline.tts_language} or options {tts_options}:"
|
||||
f" {err}"
|
||||
),
|
||||
)
|
||||
|
||||
self.tts_engine = engine
|
||||
self.tts_options = tts_options
|
||||
) from err
|
||||
|
||||
async def text_to_speech(self, tts_input: str) -> None:
|
||||
"""Run text-to-speech portion of pipeline."""
|
||||
assert self.tts_stream is not None
|
||||
|
||||
self.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.TTS_START,
|
||||
{
|
||||
"engine": self.tts_engine,
|
||||
"engine": self.tts_stream.engine,
|
||||
"language": self.pipeline.tts_language,
|
||||
"voice": self.pipeline.tts_voice,
|
||||
"tts_input": tts_input,
|
||||
|
@ -1288,14 +1280,9 @@ class PipelineRun:
|
|||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_engine,
|
||||
language=self.pipeline.tts_language,
|
||||
options=self.tts_options,
|
||||
)
|
||||
tts_media = await media_source.async_resolve_media(
|
||||
self.hass,
|
||||
tts_media_id,
|
||||
None,
|
||||
engine=self.tts_stream.engine,
|
||||
language=self.tts_stream.language,
|
||||
options=self.tts_stream.options,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
||||
|
@ -1304,10 +1291,12 @@ class PipelineRun:
|
|||
message="Unexpected error during text-to-speech",
|
||||
) from src_error
|
||||
|
||||
_LOGGER.debug("TTS result %s", tts_media)
|
||||
self.tts_stream.async_set_message(tts_input)
|
||||
|
||||
tts_output = {
|
||||
"media_id": tts_media_id,
|
||||
**asdict(tts_media),
|
||||
"url": self.tts_stream.url,
|
||||
"mime_type": self.tts_stream.content_type,
|
||||
}
|
||||
|
||||
self.process_event(
|
||||
|
|
|
@ -79,13 +79,13 @@ __all__ = [
|
|||
"PLATFORM_SCHEMA",
|
||||
"PLATFORM_SCHEMA_BASE",
|
||||
"Provider",
|
||||
"ResultStream",
|
||||
"SampleFormat",
|
||||
"TextToSpeechEntity",
|
||||
"TtsAudioType",
|
||||
"Voice",
|
||||
"async_default_engine",
|
||||
"async_get_media_source_audio",
|
||||
"async_support_options",
|
||||
"generate_media_source_id",
|
||||
]
|
||||
|
||||
|
@ -167,22 +167,19 @@ def async_resolve_engine(hass: HomeAssistant, engine: str | None) -> str | None:
|
|||
return async_default_engine(hass)
|
||||
|
||||
|
||||
async def async_support_options(
|
||||
@callback
|
||||
def async_create_stream(
|
||||
hass: HomeAssistant,
|
||||
engine: str,
|
||||
language: str | None = None,
|
||||
options: dict | None = None,
|
||||
) -> bool:
|
||||
"""Return if an engine supports options."""
|
||||
if (engine_instance := get_engine_instance(hass, engine)) is None:
|
||||
raise HomeAssistantError(f"Provider {engine} not found")
|
||||
|
||||
try:
|
||||
hass.data[DATA_TTS_MANAGER].process_options(engine_instance, language, options)
|
||||
except HomeAssistantError:
|
||||
return False
|
||||
|
||||
return True
|
||||
) -> ResultStream:
|
||||
"""Create a streaming URL where the rendered TTS can be retrieved."""
|
||||
return hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
||||
engine=engine,
|
||||
language=language,
|
||||
options=options,
|
||||
)
|
||||
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
|
@ -407,6 +404,18 @@ class ResultStream:
|
|||
"""Set cache key for message to be streamed."""
|
||||
self._result_cache_key.set_result(cache_key)
|
||||
|
||||
@callback
|
||||
def async_set_message(self, message: str) -> None:
|
||||
"""Set message to be generated."""
|
||||
cache_key = self._manager.async_cache_message_in_memory(
|
||||
engine=self.engine,
|
||||
message=message,
|
||||
use_file_cache=self.use_file_cache,
|
||||
language=self.language,
|
||||
options=self.options,
|
||||
)
|
||||
self._result_cache_key.set_result(cache_key)
|
||||
|
||||
async def async_stream_result(self) -> AsyncGenerator[bytes]:
|
||||
"""Get the stream of this result."""
|
||||
cache_key = await self._result_cache_key
|
||||
|
|
|
@ -6,6 +6,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
@ -99,6 +103,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
@ -192,6 +200,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
@ -285,6 +297,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
@ -402,6 +418,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
@ -598,6 +618,10 @@
|
|||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
|
|
|
@ -8,6 +8,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline.1
|
||||
|
@ -93,6 +97,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.1
|
||||
|
@ -190,6 +198,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements.1
|
||||
|
@ -275,6 +287,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.1
|
||||
|
@ -382,6 +398,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout.1
|
||||
|
@ -585,6 +605,10 @@
|
|||
'stt_binary_handler_id': None,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output.1
|
||||
|
@ -634,6 +658,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_different_ids.1
|
||||
|
@ -645,6 +673,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id
|
||||
|
@ -656,6 +688,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id.1
|
||||
|
@ -667,6 +703,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed
|
||||
|
@ -678,6 +718,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_stt_stream_failed.1
|
||||
|
@ -798,28 +842,6 @@
|
|||
'message': 'Timeout running pipeline',
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
'stt_binary_handler_id': None,
|
||||
'timeout': 300,
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed.1
|
||||
dict({
|
||||
'engine': 'test',
|
||||
'language': 'en-US',
|
||||
'tts_input': 'Lights are on.',
|
||||
'voice': 'james_earl_jones',
|
||||
})
|
||||
# ---
|
||||
# name: test_tts_failed.2
|
||||
None
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
|
@ -829,6 +851,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities.1
|
||||
|
@ -840,6 +866,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities.2
|
||||
|
@ -892,6 +922,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_ids.1
|
||||
|
@ -903,6 +937,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_ids.2
|
||||
|
@ -958,6 +996,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id.1
|
||||
|
@ -969,6 +1011,10 @@
|
|||
'stt_binary_handler_id': 1,
|
||||
'timeout': 300,
|
||||
}),
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
})
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id.2
|
||||
|
|
|
@ -43,13 +43,21 @@ from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid of chat sessions."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
def mock_chat_session_id() -> Generator[Mock]:
|
||||
"""Mock the conversation ID of chat sessions."""
|
||||
with patch(
|
||||
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
|
||||
) as mock_ulid_now:
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_token() -> Generator[None]:
|
||||
"""Mock the TTS token for URLs."""
|
||||
with patch("secrets.token_urlsafe", return_value="mocked-token"):
|
||||
yield
|
||||
|
||||
|
||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||
"""Process events to remove dynamic values."""
|
||||
processed = []
|
||||
|
@ -797,10 +805,16 @@ async def test_tts_audio_output(
|
|||
await pipeline_input.validate()
|
||||
|
||||
# Verify TTS audio settings
|
||||
assert pipeline_input.run.tts_options is not None
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
|
||||
assert pipeline_input.run.tts_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1
|
||||
assert pipeline_input.run.tts_stream.options is not None
|
||||
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
|
||||
== 16000
|
||||
)
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
== 1
|
||||
)
|
||||
|
||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||
await pipeline_input.execute()
|
||||
|
@ -809,9 +823,7 @@ async def test_tts_audio_output(
|
|||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
resolved = await media_source.async_resolve_media(hass, media_id, None)
|
||||
await client.get(resolved.url)
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
# Ensure that no unsupported options were passed in
|
||||
assert mock_get_tts_audio.called
|
||||
|
@ -875,9 +887,7 @@ async def test_tts_wav_preferred_format(
|
|||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
resolved = await media_source.async_resolve_media(hass, media_id, None)
|
||||
await client.get(resolved.url)
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
@ -949,9 +959,7 @@ async def test_tts_dict_preferred_format(
|
|||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
media_id = event.data["tts_output"]["media_id"]
|
||||
resolved = await media_source.async_resolve_media(hass, media_id, None)
|
||||
await client.get(resolved.url)
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
|
|
@ -20,6 +20,8 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||
DeviceAudioQueue,
|
||||
Pipeline,
|
||||
PipelineData,
|
||||
async_get_pipelines,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
@ -38,13 +40,21 @@ from tests.typing import WebSocketGenerator
|
|||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid of chat sessions."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
def mock_chat_session_id() -> Generator[Mock]:
|
||||
"""Mock the conversation ID of chat sessions."""
|
||||
with patch(
|
||||
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
|
||||
) as mock_ulid_now:
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_token() -> Generator[None]:
|
||||
"""Mock the TTS token for URLs."""
|
||||
with patch("secrets.token_urlsafe", return_value="mocked-token"):
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extra_msg",
|
||||
[
|
||||
|
@ -825,74 +835,6 @@ async def test_stt_stream_failed(
|
|||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_tts_failed(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test pipeline run with text-to-speech error."""
|
||||
events = []
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.media_source.async_resolve_media",
|
||||
side_effect=RuntimeError,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {"text": "Lights are on."},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
|
||||
# run start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-start"
|
||||
msg["event"]["data"]["pipeline"] = ANY
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# tts start
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "tts-start"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
# tts error
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "error"
|
||||
assert msg["event"]["data"]["code"] == "tts-failed"
|
||||
events.append(msg["event"])
|
||||
|
||||
# run end
|
||||
msg = await client.receive_json()
|
||||
assert msg["event"]["type"] == "run-end"
|
||||
assert msg["event"]["data"] == snapshot
|
||||
events.append(msg["event"])
|
||||
|
||||
pipeline_data: PipelineData = hass.data[DOMAIN]
|
||||
pipeline_id = list(pipeline_data.pipeline_debug)[0]
|
||||
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0]
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline_debug/get",
|
||||
"pipeline_id": pipeline_id,
|
||||
"pipeline_run_id": pipeline_run_id,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
assert msg["result"] == {"events": events}
|
||||
|
||||
|
||||
async def test_tts_provider_missing(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
|
@ -903,23 +845,22 @@ async def test_tts_provider_missing(
|
|||
"""Test pipeline run with text-to-speech error."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_support_options",
|
||||
side_effect=HomeAssistantError,
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {"text": "Lights are on."},
|
||||
}
|
||||
)
|
||||
pipelines = async_get_pipelines(hass)
|
||||
await async_update_pipeline(hass, pipelines[0], tts_engine="unavailable")
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "tts-not-supported"
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/run",
|
||||
"start_stage": "tts",
|
||||
"end_stage": "tts",
|
||||
"input": {"text": "Lights are on."},
|
||||
}
|
||||
)
|
||||
|
||||
# result
|
||||
msg = await client.receive_json()
|
||||
assert not msg["success"]
|
||||
assert msg["error"]["code"] == "tts-not-supported"
|
||||
|
||||
|
||||
async def test_tts_provider_bad_options(
|
||||
|
@ -933,8 +874,8 @@ async def test_tts_provider_bad_options(
|
|||
client = await hass_ws_client(hass)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_support_options",
|
||||
return_value=False,
|
||||
"homeassistant.components.tts.SpeechManager.process_options",
|
||||
side_effect=HomeAssistantError("Language not supported"),
|
||||
):
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
|
|
|
@ -1376,29 +1376,6 @@ def test_resolve_engine(hass: HomeAssistant, setup: str, engine_id: str) -> None
|
|||
assert tts.async_resolve_engine(hass, None) is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("setup", "engine_id"),
|
||||
[
|
||||
("mock_setup", "test"),
|
||||
("mock_config_entry_setup", "tts.test"),
|
||||
],
|
||||
indirect=["setup"],
|
||||
)
|
||||
async def test_support_options(hass: HomeAssistant, setup: str, engine_id: str) -> None:
|
||||
"""Test supporting options."""
|
||||
assert await tts.async_support_options(hass, engine_id, "en_US") is True
|
||||
assert await tts.async_support_options(hass, engine_id, "nl") is False
|
||||
assert (
|
||||
await tts.async_support_options(
|
||||
hass, engine_id, "en_US", {"invalid_option": "yo"}
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
with pytest.raises(HomeAssistantError):
|
||||
await tts.async_support_options(hass, "non-existing")
|
||||
|
||||
|
||||
async def test_legacy_fetching_in_async(
|
||||
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||
) -> None:
|
||||
|
|
Loading…
Reference in New Issue