Abort wake word detection when assist pipeline is modified (#100918)
parent
9254eea9e2
commit
a9bcfe5700
|
@ -22,6 +22,14 @@ class WakeWordDetectionError(PipelineError):
|
|||
"""Error in wake-word-detection portion of pipeline."""
|
||||
|
||||
|
||||
class WakeWordDetectionAborted(WakeWordDetectionError):
    """Error signalling that wake-word-detection was aborted before completing."""

    def __init__(self) -> None:
        """Initialize with the fixed abort error code and an empty message."""
        error_code = "wake_word_detection_aborted"
        error_message = ""
        super().__init__(error_code, error_message)
|
||||
|
||||
|
||||
class WakeWordTimeoutError(WakeWordDetectionError):
|
||||
"""Timeout when wake word was not detected."""
|
||||
|
||||
|
|
|
@ -32,6 +32,7 @@ from homeassistant.components.tts.media_source import (
|
|||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers.collection import (
|
||||
CHANGE_UPDATED,
|
||||
CollectionError,
|
||||
ItemNotFound,
|
||||
SerializedStorageCollection,
|
||||
|
@ -54,6 +55,7 @@ from .error import (
|
|||
PipelineNotFound,
|
||||
SpeechToTextError,
|
||||
TextToSpeechError,
|
||||
WakeWordDetectionAborted,
|
||||
WakeWordDetectionError,
|
||||
WakeWordTimeoutError,
|
||||
)
|
||||
|
@ -470,11 +472,13 @@ class PipelineRun:
|
|||
audio_settings: AudioSettings = field(default_factory=AudioSettings)
|
||||
|
||||
id: str = field(default_factory=ulid_util.ulid)
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
|
||||
tts_engine: str = field(init=False)
|
||||
stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
|
||||
tts_engine: str = field(init=False, repr=False)
|
||||
tts_options: dict | None = field(init=False, default=None)
|
||||
wake_word_entity_id: str = field(init=False)
|
||||
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False)
|
||||
wake_word_entity_id: str = field(init=False, repr=False)
|
||||
wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
|
||||
|
||||
abort_wake_word_detection: bool = field(init=False, default=False)
|
||||
|
||||
debug_recording_thread: Thread | None = None
|
||||
"""Thread that records audio to debug_recording_dir"""
|
||||
|
@ -485,7 +489,7 @@ class PipelineRun:
|
|||
audio_processor: AudioProcessor | None = None
|
||||
"""VAD/noise suppression/auto gain"""
|
||||
|
||||
audio_processor_buffer: AudioBuffer = field(init=False)
|
||||
audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
|
||||
"""Buffer used when splitting audio into chunks for audio processing"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
|
@ -504,6 +508,7 @@ class PipelineRun:
|
|||
size_limit=STORED_PIPELINE_RUNS
|
||||
)
|
||||
pipeline_data.pipeline_debug[self.pipeline.id][self.id] = PipelineRunDebug()
|
||||
pipeline_data.pipeline_runs.add_run(self)
|
||||
|
||||
# Initialize with audio settings
|
||||
self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
|
||||
|
@ -548,6 +553,9 @@ class PipelineRun:
|
|||
)
|
||||
)
|
||||
|
||||
pipeline_data: PipelineData = self.hass.data[DOMAIN]
|
||||
pipeline_data.pipeline_runs.remove_run(self)
|
||||
|
||||
async def prepare_wake_word_detection(self) -> None:
|
||||
"""Prepare wake-word-detection."""
|
||||
entity_id = self.pipeline.wake_word_entity or wake_word.async_default_entity(
|
||||
|
@ -638,6 +646,8 @@ class PipelineRun:
|
|||
# All audio kept from right before the wake word was detected as
|
||||
# a single chunk.
|
||||
audio_chunks_for_stt.extend(stt_audio_buffer)
|
||||
except WakeWordDetectionAborted:
|
||||
raise
|
||||
except WakeWordTimeoutError:
|
||||
_LOGGER.debug("Timeout during wake word detection")
|
||||
raise
|
||||
|
@ -696,6 +706,9 @@ class PipelineRun:
|
|||
"""
|
||||
chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
|
||||
async for chunk in audio_stream:
|
||||
if self.abort_wake_word_detection:
|
||||
raise WakeWordDetectionAborted
|
||||
|
||||
if self.debug_recording_queue is not None:
|
||||
self.debug_recording_queue.put_nowait(chunk.audio)
|
||||
|
||||
|
@ -1547,13 +1560,48 @@ class PipelineStorageCollectionWebsocket(
|
|||
connection.send_result(msg["id"])
|
||||
|
||||
|
||||
class PipelineRuns:
    """Registry of active pipeline runs, grouped by pipeline id.

    The registry subscribes to the pipeline store on construction. When a
    pipeline is updated, every run registered for that pipeline gets its
    ``abort_wake_word_detection`` flag set.

    This is a plain class on purpose: it defines its own ``__init__`` and has
    no dataclass fields, so a ``@dataclass`` decorator would only add an
    ``__eq__`` under which all instances compare equal.
    """

    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
        """Initialize and start listening for pipeline store changes."""
        # Maps pipeline id -> runs currently registered for that pipeline.
        self._pipeline_runs: dict[str, list[PipelineRun]] = {}
        self._pipeline_store = pipeline_store
        pipeline_store.async_add_listener(self._change_listener)

    def add_run(self, pipeline_run: PipelineRun) -> None:
        """Register a pipeline run under the id of its pipeline."""
        self._pipeline_runs.setdefault(pipeline_run.pipeline.id, []).append(
            pipeline_run
        )

    def remove_run(self, pipeline_run: PipelineRun) -> None:
        """Unregister a pipeline run.

        Removing a run that is not registered is a no-op, so calling this
        from cleanup code can never raise.
        """
        pipeline_id = pipeline_run.pipeline.id
        registered = self._pipeline_runs.get(pipeline_id)
        if registered is None:
            return

        # Match by identity rather than list.remove()'s equality test, so a
        # different run that merely compares equal is never removed instead.
        for index, candidate in enumerate(registered):
            if candidate is pipeline_run:
                del registered[index]
                break

        # Drop empty buckets so ids of idle pipelines do not accumulate.
        if not registered:
            del self._pipeline_runs[pipeline_id]

    async def _change_listener(
        self, change_type: str, item_id: str, change: dict
    ) -> None:
        """Flag all runs of an updated pipeline to abort wake word detection.

        Change types other than CHANGE_UPDATED are ignored.
        """
        if change_type != CHANGE_UPDATED:
            return

        # Iterate over a copy in case the list is modified while we iterate.
        for pipeline_run in list(self._pipeline_runs.get(item_id, ())):
            pipeline_run.abort_wake_word_detection = True
|
||||
|
||||
|
||||
class PipelineData:
    """Store and debug data stored in hass.data."""

    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
        """Initialize.

        Only the pipeline store is supplied by the caller; the debug map,
        device set and run registry all start out empty.
        """
        self.pipeline_store = pipeline_store
        # Assigned here rather than declared with dataclasses.field(): this
        # is a plain class, so a class-level field() would leave a Field
        # object behind as the class attribute.
        self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
        self.pipeline_devices: set[str] = set()
        # The run registry subscribes itself to changes of the same store.
        self.pipeline_runs = PipelineRuns(pipeline_store)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -1605,4 +1653,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
|
|||
PIPELINE_FIELDS,
|
||||
PIPELINE_FIELDS,
|
||||
).async_setup(hass)
|
||||
return PipelineData({}, pipeline_store)
|
||||
return PipelineData(pipeline_store)
|
||||
|
|
|
@ -377,3 +377,38 @@
|
|||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_wake_word_detection_aborted
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'entity_id': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'code': 'wake_word_detection_aborted',
|
||||
'message': '',
|
||||
}),
|
||||
'type': <PipelineEventType.ERROR: 'error'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
|
|
@ -563,3 +563,67 @@ async def test_pipeline_saved_audio_write_error(
|
|||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.STT,
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_detection_aborted(
    hass: HomeAssistant,
    mock_stt_provider: MockSttProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that wake word detection is aborted after the pipeline is updated."""
    events: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        for chunk in (b"silence!", b"wake word!", b"part1", b"part2", b""):
            yield chunk

    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    metadata = stt.SpeechMetadata(
        language="",
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    audio_stream = audio_data()
    run = assist_pipeline.pipeline.PipelineRun(
        hass,
        context=Context(),
        pipeline=pipeline,
        start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
        end_stage=assist_pipeline.PipelineStage.TTS,
        event_callback=events.append,
        tts_audio_output=None,
        wake_word_settings=assist_pipeline.WakeWordSettings(
            audio_seconds_to_buffer=1.5
        ),
        audio_settings=assist_pipeline.AudioSettings(
            is_vad_enabled=False, is_chunking_enabled=False
        ),
    )
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        conversation_id=None,
        device_id=None,
        stt_metadata=metadata,
        stt_stream=audio_stream,
        run=run,
    )
    await pipeline_input.validate()

    # Write the pipeline back to the store with unchanged settings before the
    # run executes; the store update is expected to flag the run for abort.
    unchanged_settings = pipeline.to_json()
    unchanged_settings.pop("id")
    await store.async_update_item(preferred_id, unchanged_settings)

    await pipeline_input.execute()

    # The snapshot expects a "wake_word_detection_aborted" error event.
    assert process_events(events) == snapshot
|
||||
|
|
Loading…
Reference in New Issue