Cancel running pipeline on new pipeline or announcement (#125687)

* Cancel running pipeline

* Incorporate feedback

* Change to async_create_task
pull/125713/head^2
Michael Hansen 2024-09-10 19:56:15 -05:00 committed by GitHub
parent c01bdd860a
commit 8e0b2b752c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 138 additions and 22 deletions

View File

@ -3,6 +3,7 @@
from abc import abstractmethod
import asyncio
from collections.abc import AsyncIterable
import contextlib
from enum import StrEnum
import logging
import time
@ -73,6 +74,7 @@ class AssistSatelliteEntity(entity.Entity):
_is_announcing = False
_wake_word_intercept_future: asyncio.Future[str | None] | None = None
_attr_tts_options: dict[str, Any] | None = None
_pipeline_task: asyncio.Task | None = None
__assist_satellite_state = AssistSatelliteState.LISTENING_WAKE_WORD
@ -131,6 +133,8 @@ class AssistSatelliteEntity(entity.Entity):
Calls async_announce with message and media id.
"""
await self._cancel_running_pipeline()
if message is None:
message = ""
@ -176,7 +180,7 @@ class AssistSatelliteEntity(entity.Entity):
await self.async_announce(message, media_id)
finally:
self._is_announcing = False
self.tts_response_finished()
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
async def async_announce(self, message: str, media_id: str) -> None:
"""Announce media on the satellite.
@ -193,6 +197,8 @@ class AssistSatelliteEntity(entity.Entity):
wake_word_phrase: str | None = None,
) -> None:
"""Triggers an Assist pipeline in Home Assistant from a satellite."""
await self._cancel_running_pipeline()
if self._wake_word_intercept_future and start_stage in (
PipelineStage.WAKE_WORD,
PipelineStage.STT,
@ -248,31 +254,50 @@ class AssistSatelliteEntity(entity.Entity):
# Set entity state based on pipeline events
self._run_has_tts = False
await async_pipeline_from_audio_stream(
assert self.platform.config_entry is not None
self._pipeline_task = self.platform.config_entry.async_create_background_task(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
async_pipeline_from_audio_stream(
self.hass,
context=self._context,
event_callback=self._internal_on_pipeline_event,
stt_metadata=stt.SpeechMetadata(
language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16,
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
channel=stt.AudioChannels.CHANNEL_MONO,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
),
stt_stream=audio_stream,
pipeline_id=self._resolve_pipeline(),
conversation_id=self._conversation_id,
device_id=device_id,
tts_audio_output=self.tts_options,
wake_word_phrase=wake_word_phrase,
audio_settings=AudioSettings(
silence_seconds=self._resolve_vad_sensitivity()
),
start_stage=start_stage,
end_stage=end_stage,
f"{self.entity_id}_pipeline",
)
try:
await self._pipeline_task
finally:
self._pipeline_task = None
async def _cancel_running_pipeline(self) -> None:
"""Cancel the current pipeline if it's running."""
if self._pipeline_task is not None:
self._pipeline_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._pipeline_task
self._pipeline_task = None
@abstractmethod
def on_pipeline_event(self, event: PipelineEvent) -> None:
"""Handle pipeline events."""

View File

@ -93,6 +93,55 @@ async def test_entity_state(
assert state.state == AssistSatelliteState.LISTENING_WAKE_WORD
async def test_new_pipeline_cancels_pipeline(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test that a new pipeline run cancels any running pipeline."""
pipeline1_started = asyncio.Event()
pipeline1_finished = asyncio.Event()
pipeline1_cancelled = asyncio.Event()
pipeline2_finished = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
if not pipeline1_started.is_set():
# First pipeline run
pipeline1_started.set()
# Wait for pipeline to be cancelled
try:
await pipeline1_finished.wait()
except asyncio.CancelledError:
pipeline1_cancelled.set()
raise
else:
# Second pipeline run
pipeline2_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
):
hass.async_create_task(
entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
)
async with asyncio.timeout(1):
await pipeline1_started.wait()
# Start a second pipeline
await entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
await pipeline1_cancelled.wait()
await pipeline2_finished.wait()
@pytest.mark.parametrize(
("service_data", "expected_params"),
[
@ -210,6 +259,48 @@ async def test_announce_busy(
await announce_task
async def test_announce_cancels_pipeline(
hass: HomeAssistant,
init_components: ConfigEntry,
entity: MockAssistSatellite,
) -> None:
"""Test that announcements cancel any running pipeline."""
media_id = "https://www.home-assistant.io/resolved.mp3"
pipeline_started = asyncio.Event()
pipeline_finished = asyncio.Event()
pipeline_cancelled = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
pipeline_started.set()
# Wait for pipeline to be cancelled
try:
await pipeline_finished.wait()
except asyncio.CancelledError:
pipeline_cancelled.set()
raise
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(entity, "async_announce") as mock_async_announce,
):
hass.async_create_task(
entity.async_accept_pipeline_from_satellite(
object(), # type: ignore[arg-type]
)
)
async with asyncio.timeout(1):
await pipeline_started.wait()
await entity.async_internal_announce(None, media_id)
await pipeline_cancelled.wait()
mock_async_announce.assert_called_once()
async def test_context_refresh(
hass: HomeAssistant, init_components: ConfigEntry, entity: MockAssistSatellite
) -> None: