Abort wake word detection when assist pipeline is modified (#100918)

Erik Montnemery 2023-09-26 20:24:55 +02:00 committed by GitHub
parent 9254eea9e2
commit a9bcfe5700
4 changed files with 165 additions and 10 deletions

View File

@@ -22,6 +22,14 @@ class WakeWordDetectionError(PipelineError):
     """Error in wake-word-detection portion of pipeline."""
 
 
+class WakeWordDetectionAborted(WakeWordDetectionError):
+    """Wake-word-detection was aborted."""
+
+    def __init__(self) -> None:
+        """Set error message."""
+        super().__init__("wake_word_detection_aborted", "")
+
+
 class WakeWordTimeoutError(WakeWordDetectionError):
     """Timeout when wake word was not detected."""

View File

@@ -32,6 +32,7 @@ from homeassistant.components.tts.media_source import (
 from homeassistant.core import Context, HomeAssistant, callback
 from homeassistant.exceptions import HomeAssistantError
 from homeassistant.helpers.collection import (
+    CHANGE_UPDATED,
     CollectionError,
     ItemNotFound,
     SerializedStorageCollection,
@@ -54,6 +55,7 @@ from .error import (
     PipelineNotFound,
     SpeechToTextError,
     TextToSpeechError,
+    WakeWordDetectionAborted,
     WakeWordDetectionError,
     WakeWordTimeoutError,
 )
@@ -470,11 +472,13 @@ class PipelineRun:
     audio_settings: AudioSettings = field(default_factory=AudioSettings)
 
     id: str = field(default_factory=ulid_util.ulid)
-    stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False)
-    tts_engine: str = field(init=False)
+    stt_provider: stt.SpeechToTextEntity | stt.Provider = field(init=False, repr=False)
+    tts_engine: str = field(init=False, repr=False)
     tts_options: dict | None = field(init=False, default=None)
-    wake_word_entity_id: str = field(init=False)
-    wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False)
+    wake_word_entity_id: str = field(init=False, repr=False)
+    wake_word_entity: wake_word.WakeWordDetectionEntity = field(init=False, repr=False)
+    abort_wake_word_detection: bool = field(init=False, default=False)
 
     debug_recording_thread: Thread | None = None
     """Thread that records audio to debug_recording_dir"""
@@ -485,7 +489,7 @@
     audio_processor: AudioProcessor | None = None
     """VAD/noise suppression/auto gain"""
 
-    audio_processor_buffer: AudioBuffer = field(init=False)
+    audio_processor_buffer: AudioBuffer = field(init=False, repr=False)
     """Buffer used when splitting audio into chunks for audio processing"""
 
     def __post_init__(self) -> None:
@@ -504,6 +508,7 @@
                 size_limit=STORED_PIPELINE_RUNS
             )
         pipeline_data.pipeline_debug[self.pipeline.id][self.id] = PipelineRunDebug()
+        pipeline_data.pipeline_runs.add_run(self)
 
         # Initialize with audio settings
         self.audio_processor_buffer = AudioBuffer(AUDIO_PROCESSOR_BYTES)
@@ -548,6 +553,9 @@
             )
         )
 
+        pipeline_data: PipelineData = self.hass.data[DOMAIN]
+        pipeline_data.pipeline_runs.remove_run(self)
+
     async def prepare_wake_word_detection(self) -> None:
         """Prepare wake-word-detection."""
         entity_id = self.pipeline.wake_word_entity or wake_word.async_default_entity(
@@ -638,6 +646,8 @@
                 # All audio kept from right before the wake word was detected as
                 # a single chunk.
                 audio_chunks_for_stt.extend(stt_audio_buffer)
+        except WakeWordDetectionAborted:
+            raise
         except WakeWordTimeoutError:
             _LOGGER.debug("Timeout during wake word detection")
             raise
@@ -696,6 +706,9 @@
         """
         chunk_seconds = AUDIO_PROCESSOR_SAMPLES / sample_rate
         async for chunk in audio_stream:
+            if self.abort_wake_word_detection:
+                raise WakeWordDetectionAborted
+
             if self.debug_recording_queue is not None:
                 self.debug_recording_queue.put_nowait(chunk.audio)
@@ -1547,13 +1560,48 @@ class PipelineStorageCollectionWebsocket(
         connection.send_result(msg["id"])
 
 
-@dataclass
+class PipelineRuns:
+    """Class managing pipelineruns."""
+
+    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
+        """Initialize."""
+        self._pipeline_runs: dict[str, list[PipelineRun]] = {}
+        self._pipeline_store = pipeline_store
+        pipeline_store.async_add_listener(self._change_listener)
+
+    def add_run(self, pipeline_run: PipelineRun) -> None:
+        """Add pipeline run."""
+        pipeline_id = pipeline_run.pipeline.id
+        if pipeline_id not in self._pipeline_runs:
+            self._pipeline_runs[pipeline_id] = []
+        self._pipeline_runs[pipeline_id].append(pipeline_run)
+
+    def remove_run(self, pipeline_run: PipelineRun) -> None:
+        """Remove pipeline run."""
+        pipeline_id = pipeline_run.pipeline.id
+        self._pipeline_runs[pipeline_id].remove(pipeline_run)
+
+    async def _change_listener(
+        self, change_type: str, item_id: str, change: dict
+    ) -> None:
+        """Handle pipeline store changes."""
+        if change_type != CHANGE_UPDATED:
+            return
+        if pipeline_runs := self._pipeline_runs.get(item_id):
+            # Create a temporary list in case the list is modified while we iterate
+            for pipeline_run in list(pipeline_runs):
+                pipeline_run.abort_wake_word_detection = True
+
+
 class PipelineData:
     """Store and debug data stored in hass.data."""
 
-    pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]]
-    pipeline_store: PipelineStorageCollection
-    pipeline_devices: set[str] = field(default_factory=set, init=False)
+    def __init__(self, pipeline_store: PipelineStorageCollection) -> None:
+        """Initialize."""
+        self.pipeline_store = pipeline_store
+        self.pipeline_debug: dict[str, LimitedSizeDict[str, PipelineRunDebug]] = {}
+        self.pipeline_devices: set[str] = set()
+        self.pipeline_runs = PipelineRuns(pipeline_store)
 
 
 @dataclass
@@ -1605,4 +1653,4 @@ async def async_setup_pipeline_store(hass: HomeAssistant) -> PipelineData:
         PIPELINE_FIELDS,
         PIPELINE_FIELDS,
     ).async_setup(hass)
-    return PipelineData({}, pipeline_store)
+    return PipelineData(pipeline_store)
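Taken together, the changes in this file implement a simple cooperative-cancellation pattern: every PipelineRun registers itself per pipeline id, the storage collection's update listener flips an abort flag on all in-flight runs for that pipeline, and the wake word audio loop checks the flag on each chunk. The standalone sketch below mirrors that shape with hypothetical names (RunRegistry, Run, on_pipeline_updated); it is an illustration of the pattern, not the Home Assistant code itself.

import asyncio


class RunRegistry:
    """Simplified stand-in for PipelineRuns: tracks in-flight runs per pipeline id."""

    def __init__(self) -> None:
        self._runs: dict[str, list["Run"]] = {}

    def add_run(self, run: "Run") -> None:
        self._runs.setdefault(run.pipeline_id, []).append(run)

    def remove_run(self, run: "Run") -> None:
        self._runs[run.pipeline_id].remove(run)

    def on_pipeline_updated(self, pipeline_id: str) -> None:
        # Flip the abort flag on every run of the updated pipeline; iterate over
        # a copy in case a run removes itself while we loop.
        for run in list(self._runs.get(pipeline_id, [])):
            run.abort_wake_word_detection = True


class Run:
    """Simplified stand-in for PipelineRun's wake word stage."""

    def __init__(self, pipeline_id: str, registry: RunRegistry) -> None:
        self.pipeline_id = pipeline_id
        self.abort_wake_word_detection = False
        registry.add_run(self)

    async def wake_word_stage(self) -> str:
        # The real code raises WakeWordDetectionAborted inside the audio chunk loop.
        while True:
            if self.abort_wake_word_detection:
                return "aborted"
            await asyncio.sleep(0.05)  # stand-in for processing one audio chunk


async def main() -> None:
    registry = RunRegistry()
    run = Run("preferred_pipeline", registry)
    task = asyncio.create_task(run.wake_word_stage())
    await asyncio.sleep(0.1)
    registry.on_pipeline_updated("preferred_pipeline")  # e.g. the user edits the pipeline
    print(await task)  # prints "aborted"
    registry.remove_run(run)


asyncio.run(main())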

View File

@@ -377,3 +377,38 @@
     }),
   ])
 # ---
+# name: test_wake_word_detection_aborted
+  list([
+    dict({
+      'data': dict({
+        'language': 'en',
+        'pipeline': <ANY>,
+      }),
+      'type': <PipelineEventType.RUN_START: 'run-start'>,
+    }),
+    dict({
+      'data': dict({
+        'entity_id': 'wake_word.test',
+        'metadata': dict({
+          'bit_rate': <AudioBitRates.BITRATE_16: 16>,
+          'channel': <AudioChannels.CHANNEL_MONO: 1>,
+          'codec': <AudioCodecs.PCM: 'pcm'>,
+          'format': <AudioFormats.WAV: 'wav'>,
+          'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
+        }),
+      }),
+      'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
+    }),
+    dict({
+      'data': dict({
+        'code': 'wake_word_detection_aborted',
+        'message': '',
+      }),
+      'type': <PipelineEventType.ERROR: 'error'>,
+    }),
+    dict({
+      'data': None,
+      'type': <PipelineEventType.RUN_END: 'run-end'>,
+    }),
+  ])
+# ---
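For consumers of pipeline events, an aborted run still ends with an error event, but with the dedicated code shown above and an empty message. Below is a minimal sketch of a callback that treats that code as a quiet stop; it assumes PipelineEvent and PipelineEventType are importable from the assist_pipeline component package, as in the test file that follows, and handle_event is a hypothetical consumer, not part of this commit.

from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType


def handle_event(event: PipelineEvent) -> None:
    """Treat an aborted wake word stage as a normal stop, not a failure."""
    if event.type != PipelineEventType.ERROR:
        return
    if event.data and event.data.get("code") == "wake_word_detection_aborted":
        # The pipeline was edited mid-run; simply stop listening.
        return
    print(f"Pipeline error: {event.data}")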

View File

@@ -563,3 +563,67 @@ async def test_pipeline_saved_audio_write_error(
             start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
             end_stage=assist_pipeline.PipelineStage.STT,
         )
+
+
+async def test_wake_word_detection_aborted(
+    hass: HomeAssistant,
+    mock_stt_provider: MockSttProvider,
+    mock_wake_word_provider_entity: MockWakeWordEntity,
+    init_components,
+    pipeline_data: assist_pipeline.pipeline.PipelineData,
+    snapshot: SnapshotAssertion,
+) -> None:
+    """Test that wake word detection is aborted when the pipeline is updated."""
+    events: list[assist_pipeline.PipelineEvent] = []
+
+    async def audio_data():
+        yield b"silence!"
+        yield b"wake word!"
+        yield b"part1"
+        yield b"part2"
+        yield b""
+
+    pipeline_store = pipeline_data.pipeline_store
+    pipeline_id = pipeline_store.async_get_preferred_item()
+    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
+
+    pipeline_input = assist_pipeline.pipeline.PipelineInput(
+        conversation_id=None,
+        device_id=None,
+        stt_metadata=stt.SpeechMetadata(
+            language="",
+            format=stt.AudioFormats.WAV,
+            codec=stt.AudioCodecs.PCM,
+            bit_rate=stt.AudioBitRates.BITRATE_16,
+            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
+            channel=stt.AudioChannels.CHANNEL_MONO,
+        ),
+        stt_stream=audio_data(),
+        run=assist_pipeline.pipeline.PipelineRun(
+            hass,
+            context=Context(),
+            pipeline=pipeline,
+            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
+            end_stage=assist_pipeline.PipelineStage.TTS,
+            event_callback=events.append,
+            tts_audio_output=None,
+            wake_word_settings=assist_pipeline.WakeWordSettings(
+                audio_seconds_to_buffer=1.5
+            ),
+            audio_settings=assist_pipeline.AudioSettings(
+                is_vad_enabled=False, is_chunking_enabled=False
+            ),
+        ),
+    )
+    await pipeline_input.validate()
+
+    updates = pipeline.to_json()
+    updates.pop("id")
+    await pipeline_store.async_update_item(
+        pipeline_id,
+        updates,
+    )
+
+    await pipeline_input.execute()
+
+    assert process_events(events) == snapshot