core/homeassistant/components/assist_pipeline/vad.py

305 lines
9.9 KiB
Python

"""Voice activity detection."""
from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import StrEnum
from typing import Final
import webrtcvad
# Audio format constants used by this module to convert a per-chunk sample
# count into a byte count (_SAMPLE_WIDTH) and into seconds (_SAMPLE_RATE).
_SAMPLE_RATE: Final = 16000  # Hz
_SAMPLE_WIDTH: Final = 2  # bytes
class VadSensitivity(StrEnum):
"""How quickly the end of a voice command is detected."""
DEFAULT = "default"
RELAXED = "relaxed"
AGGRESSIVE = "aggressive"
@staticmethod
def to_seconds(sensitivity: VadSensitivity | str) -> float:
"""Return seconds of silence for sensitivity level."""
sensitivity = VadSensitivity(sensitivity)
if sensitivity == VadSensitivity.RELAXED:
return 2.0
if sensitivity == VadSensitivity.AGGRESSIVE:
return 0.5
return 1.0
class AudioBuffer:
    """Pre-allocated byte buffer of fixed capacity with a movable fill level."""

    def __init__(self, maxlen: int) -> None:
        """Allocate maxlen bytes of storage and start out empty."""
        self._length = 0
        self._buffer = bytearray(maxlen)

    @property
    def length(self) -> int:
        """Number of bytes that have been written so far."""
        return self._length

    def clear(self) -> None:
        """Forget the written bytes (storage is kept, only the length resets)."""
        self._length = 0

    def append(self, data: bytes) -> None:
        """Copy data behind the bytes already written.

        Raises ValueError when the result would not fit in the buffer.
        """
        new_length = self._length + len(data)
        if new_length > len(self._buffer):
            raise ValueError("Length cannot be greater than buffer size")

        self._buffer[self._length : new_length] = data
        self._length = new_length

    def bytes(self) -> bytes:
        """Return a copy of the written part of the buffer as bytes."""
        written = self._buffer[: self._length]
        return bytes(written)

    def __len__(self) -> int:
        """Number of bytes that have been written so far."""
        return self._length

    def __bool__(self) -> bool:
        """Return True when at least one byte has been written."""
        return self._length > 0
@dataclass
class VoiceCommandSegmenter:
    """Cut a continuous audio stream into voice commands with webrtcvad.

    A command starts once enough speech has been heard and ends once enough
    silence follows it, or when the overall timeout runs out.
    """

    vad_mode: int = 3
    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

    vad_samples_per_chunk: int = 480  # 30 ms
    """Samples per VAD chunk. Must be 10, 20, or 30 ms at 16Khz."""

    speech_seconds: float = 0.3
    """Seconds of speech required before a voice command counts as started."""

    silence_seconds: float = 0.5
    """Seconds of silence required before a voice command counts as ended."""

    timeout_seconds: float = 15.0
    """Maximum number of seconds of audio before processing stops."""

    reset_seconds: float = 1.0
    """Seconds before the start/stop time counters are restored."""

    in_command: bool = False
    """True while inside a voice command."""

    _speech_seconds_left: float = 0.0
    """Remaining speech time until the voice command counts as started."""

    _silence_seconds_left: float = 0.0
    """Remaining silence time until the voice command counts as stopped."""

    _timeout_seconds_left: float = 0.0
    """Remaining time until the voice command counts as timed out."""

    _reset_seconds_left: float = 0.0
    """Remaining time until the start/stop counters are restored."""

    # Replaced by a real webrtcvad.Vad instance in __post_init__.
    _vad: webrtcvad.Vad = None
    _leftover_chunk_buffer: AudioBuffer = field(init=False)
    _bytes_per_chunk: int = field(init=False)
    _seconds_per_chunk: float = field(init=False)

    def __post_init__(self) -> None:
        """Create the VAD, derive chunk sizes, and start from a clean state."""
        self._vad = webrtcvad.Vad(self.vad_mode)

        samples_per_chunk = self.vad_samples_per_chunk
        self._bytes_per_chunk = samples_per_chunk * _SAMPLE_WIDTH
        self._seconds_per_chunk = samples_per_chunk / _SAMPLE_RATE

        # The buffer holds at most one chunk worth of bytes.
        self._leftover_chunk_buffer = AudioBuffer(samples_per_chunk * _SAMPLE_WIDTH)
        self.reset()

    def reset(self) -> None:
        """Restore every counter to its configured value and leave the command."""
        self._leftover_chunk_buffer.clear()
        self.in_command = False
        self._reset_seconds_left = self.reset_seconds
        self._timeout_seconds_left = self.timeout_seconds
        self._silence_seconds_left = self.silence_seconds
        self._speech_seconds_left = self.speech_seconds

    def process(self, samples: bytes) -> bool:
        """Feed 16-bit 16Khz mono audio samples to the segmenter.

        Returns False when the command is done (the segmenter is reset first),
        True when more audio is wanted.
        """
        chunks = chunk_samples(
            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
        )
        for chunk in chunks:
            if self._process_chunk(chunk):
                continue

            # Command finished or timed out: start over for the next one.
            self.reset()
            return False

        return True

    @property
    def audio_buffer(self) -> bytes:
        """Bytes of the incomplete chunk currently held back in the buffer."""
        return self._leftover_chunk_buffer.bytes()

    def _process_chunk(self, chunk: bytes) -> bool:
        """Advance the state machine by one chunk of 16-bit 16Khz mono audio.

        Returns False when the command is done or the timeout has elapsed.
        """
        chunk_seconds = self._seconds_per_chunk
        is_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)

        # The timeout runs down on every chunk, speech or not.
        self._timeout_seconds_left -= chunk_seconds
        if self._timeout_seconds_left <= 0:
            return False

        if self.in_command:
            if is_speech:
                # Enough speech inside the command restores the silence budget.
                self._reset_seconds_left -= chunk_seconds
                if self._reset_seconds_left <= 0:
                    self._silence_seconds_left = self.silence_seconds
                return True

            # Silence inside the command counts towards its end.
            self._reset_seconds_left = self.reset_seconds
            self._silence_seconds_left -= chunk_seconds
            if self._silence_seconds_left <= 0:
                return False
            return True

        if is_speech:
            # Speech before the command counts towards its start.
            self._reset_seconds_left = self.reset_seconds
            self._speech_seconds_left -= chunk_seconds
            if self._speech_seconds_left <= 0:
                self.in_command = True
            return True

        # Enough silence before the command restores the speech budget.
        self._reset_seconds_left -= chunk_seconds
        if self._reset_seconds_left <= 0:
            self._speech_seconds_left = self.speech_seconds
        return True
@dataclass
class VoiceActivityTimeout:
    """Watch audio for silence and report when it has lasted long enough."""

    silence_seconds: float
    """Seconds of silence before timeout."""

    reset_seconds: float = 0.5
    """Seconds of speech before the timeout is restored."""

    vad_mode: int = 3
    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

    vad_samples_per_chunk: int = 480  # 30 ms
    """Samples per VAD chunk. Must be 10, 20, or 30 ms at 16Khz."""

    _silence_seconds_left: float = 0.0
    """Remaining silence time until the timeout is reached."""

    _reset_seconds_left: float = 0.0
    """Remaining speech time until the silence counter is restored."""

    # Replaced by a real webrtcvad.Vad instance in __post_init__.
    _vad: webrtcvad.Vad = None
    _leftover_chunk_buffer: AudioBuffer = field(init=False)
    _bytes_per_chunk: int = field(init=False)
    _seconds_per_chunk: float = field(init=False)

    def __post_init__(self) -> None:
        """Create the VAD, derive chunk sizes, and start from a clean state."""
        self._vad = webrtcvad.Vad(self.vad_mode)

        samples_per_chunk = self.vad_samples_per_chunk
        self._bytes_per_chunk = samples_per_chunk * _SAMPLE_WIDTH
        self._seconds_per_chunk = samples_per_chunk / _SAMPLE_RATE

        # The buffer holds at most one chunk worth of bytes.
        self._leftover_chunk_buffer = AudioBuffer(samples_per_chunk * _SAMPLE_WIDTH)
        self.reset()

    def reset(self) -> None:
        """Drop buffered audio and restore both counters to their configured values."""
        self._leftover_chunk_buffer.clear()
        self._reset_seconds_left = self.reset_seconds
        self._silence_seconds_left = self.silence_seconds

    def process(self, samples: bytes) -> bool:
        """Feed 16-bit 16Khz mono audio samples to the detector.

        Returns False when the timeout is reached, True otherwise.
        """
        chunks = chunk_samples(
            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
        )
        for chunk in chunks:
            if self._process_chunk(chunk):
                continue
            return False

        return True

    def _process_chunk(self, chunk: bytes) -> bool:
        """Account for one chunk of 16-bit 16Khz mono audio.

        Returns False when the timeout is reached.
        """
        chunk_seconds = self._seconds_per_chunk

        if self._vad.is_speech(chunk, _SAMPLE_RATE):
            # Speech: once enough of it has been heard, restore the timeout.
            self._reset_seconds_left -= chunk_seconds
            if self._reset_seconds_left <= 0:
                self._silence_seconds_left = self.silence_seconds
            return True

        # Silence: run the timeout down.
        self._silence_seconds_left -= chunk_seconds
        if self._silence_seconds_left <= 0:
            return False

        # Let the reset counter recover gradually, capped at its configured value.
        recovered = self._reset_seconds_left + chunk_seconds
        self._reset_seconds_left = min(self.reset_seconds, recovered)
        return True
def chunk_samples(
    samples: bytes,
    bytes_per_chunk: int,
    leftover_chunk_buffer: AudioBuffer,
) -> Iterable[bytes]:
    """Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s).

    Bytes that do not fill a whole chunk are stored in leftover_chunk_buffer
    and are prepended to the samples of the next call.

    This is a generator: nothing happens (including the buffer updates and
    the argument check) until it is iterated. If the consumer stops iterating
    early, the samples that were not yet yielded are dropped.

    Raises:
        ValueError: bytes_per_chunk is not positive, or the leftover bytes do
            not fit in leftover_chunk_buffer (raised by its append method).
    """
    if bytes_per_chunk <= 0:
        # A non-positive chunk size would never advance the loop below.
        raise ValueError("bytes_per_chunk must be positive")

    if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
        # Extend leftover chunk, but not enough samples to complete it
        leftover_chunk_buffer.append(samples)
        return

    next_chunk_idx = 0

    if leftover_chunk_buffer:
        # Add to leftover chunk from previous call(s).
        bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
        leftover_chunk_buffer.append(samples[:bytes_to_copy])
        next_chunk_idx = bytes_to_copy

        # Empty the buffer *before* handing out the completed chunk. If the
        # consumer stops iterating on this chunk, code after the yield never
        # runs, and a buffer cleared afterwards would stay full and be
        # yielded a second time on the next call.
        full_chunk = leftover_chunk_buffer.bytes()
        leftover_chunk_buffer.clear()
        yield full_chunk

    while next_chunk_idx + bytes_per_chunk <= len(samples):
        # Process full chunk
        yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
        next_chunk_idx += bytes_per_chunk

    # Capture leftover bytes for the next call
    if rest_samples := samples[next_chunk_idx:]:
        leftover_chunk_buffer.append(rest_samples)