305 lines
9.9 KiB
Python
305 lines
9.9 KiB
Python
"""Voice activity detection."""
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from dataclasses import dataclass, field
|
|
from enum import StrEnum
|
|
from typing import Final
|
|
|
|
import webrtcvad
|
|
|
|
# Audio sample rate in Hz; passed to the VAD and used to convert a chunk's
# sample count into seconds.
_SAMPLE_RATE: Final = 16000 # Hz
|
|
# Size of one audio sample in bytes; multiplied by the samples-per-chunk
# setting to get the byte size of a VAD chunk.
_SAMPLE_WIDTH: Final = 2 # bytes
|
|
|
|
|
|
class VadSensitivity(StrEnum):
|
|
"""How quickly the end of a voice command is detected."""
|
|
|
|
DEFAULT = "default"
|
|
RELAXED = "relaxed"
|
|
AGGRESSIVE = "aggressive"
|
|
|
|
@staticmethod
|
|
def to_seconds(sensitivity: VadSensitivity | str) -> float:
|
|
"""Return seconds of silence for sensitivity level."""
|
|
sensitivity = VadSensitivity(sensitivity)
|
|
if sensitivity == VadSensitivity.RELAXED:
|
|
return 2.0
|
|
|
|
if sensitivity == VadSensitivity.AGGRESSIVE:
|
|
return 0.5
|
|
|
|
return 1.0
|
|
|
|
|
|
class AudioBuffer:
    """Preallocated byte buffer that tracks how much of its capacity is in use."""

    def __init__(self, maxlen: int) -> None:
        """Allocate storage for ``maxlen`` bytes and start out empty."""
        self._buffer = bytearray(maxlen)
        self._length = 0

    @property
    def length(self) -> int:
        """Return how many bytes have been written so far."""
        return self._length

    def clear(self) -> None:
        """Mark the buffer as empty; the underlying storage is kept."""
        self._length = 0

    def append(self, data: bytes) -> None:
        """Copy ``data`` after the bytes already written.

        Raises ValueError if the data would not fit in the allocated storage.
        """
        start = self._length
        end = start + len(data)
        if end > len(self._buffer):
            raise ValueError("Length cannot be greater than buffer size")

        self._buffer[start:end] = data
        self._length = end

    def bytes(self) -> bytes:
        """Return an immutable copy of the written part of the buffer."""
        written = self._buffer[: self._length]
        return bytes(written)

    def __len__(self) -> int:
        """Return how many bytes have been written so far."""
        return self._length

    def __bool__(self) -> bool:
        """Return True when at least one byte has been written."""
        return 0 < self._length
|
|
|
|
|
|
@dataclass
class VoiceCommandSegmenter:
    """Finds the start and end of a voice command in streamed audio using webrtcvad."""

    vad_mode: int = 3
    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

    vad_samples_per_chunk: int = 480  # 30 ms
    """Samples per VAD chunk. Must be 10, 20, or 30 ms at 16Khz."""

    speech_seconds: float = 0.3
    """Seconds of speech required before the voice command counts as started."""

    silence_seconds: float = 0.5
    """Seconds of silence required before the voice command counts as ended."""

    timeout_seconds: float = 15.0
    """Maximum seconds of processed audio before the command is ended regardless."""

    reset_seconds: float = 1.0
    """Seconds of opposing audio before the start/stop counters are restored."""

    in_command: bool = False
    """True while inside a voice command."""

    _speech_seconds_left: float = 0.0
    """Seconds of speech still needed before the command counts as started."""

    _silence_seconds_left: float = 0.0
    """Seconds of silence still needed before the command counts as ended."""

    _timeout_seconds_left: float = 0.0
    """Seconds of audio left before processing stops due to the timeout."""

    _reset_seconds_left: float = 0.0
    """Seconds left before the start/stop counter is restored."""

    _vad: webrtcvad.Vad = None
    _leftover_chunk_buffer: AudioBuffer = field(init=False)
    _bytes_per_chunk: int = field(init=False)
    _seconds_per_chunk: float = field(init=False)

    def __post_init__(self) -> None:
        """Create the VAD, derive chunk sizes and put all counters in their start state."""
        samples_per_chunk = self.vad_samples_per_chunk

        self._vad = webrtcvad.Vad(self.vad_mode)
        self._bytes_per_chunk = samples_per_chunk * _SAMPLE_WIDTH
        self._seconds_per_chunk = samples_per_chunk / _SAMPLE_RATE
        # The leftover buffer holds at most one (partial) chunk.
        self._leftover_chunk_buffer = AudioBuffer(self._bytes_per_chunk)
        self.reset()

    def reset(self) -> None:
        """Drop buffered audio and restore every counter and flag to its initial value."""
        self._leftover_chunk_buffer.clear()
        self.in_command = False
        self._speech_seconds_left = self.speech_seconds
        self._silence_seconds_left = self.silence_seconds
        self._timeout_seconds_left = self.timeout_seconds
        self._reset_seconds_left = self.reset_seconds

    def process(self, samples: bytes) -> bool:
        """Process 16-bit 16Khz mono audio samples.

        Returns False when the command is done (ended or timed out); the
        segmenter has already been reset in that case. Returns True when
        more audio is needed.
        """
        chunks = chunk_samples(
            samples, self._bytes_per_chunk, self._leftover_chunk_buffer
        )
        for chunk in chunks:
            if self._process_chunk(chunk):
                continue

            # Command finished: remaining samples are discarded by the reset.
            self.reset()
            return False

        return True

    @property
    def audio_buffer(self) -> bytes:
        """Return the bytes of the partial chunk waiting in the leftover buffer."""
        return self._leftover_chunk_buffer.bytes()

    def _process_chunk(self, chunk: bytes) -> bool:
        """Process a single chunk of 16-bit 16Khz mono audio.

        Returns False when the command is done (ended or timed out).
        """
        heard_speech = self._vad.is_speech(chunk, _SAMPLE_RATE)
        step = self._seconds_per_chunk

        # The overall timeout runs down on every chunk, speech or not.
        self._timeout_seconds_left -= step
        if self._timeout_seconds_left <= 0:
            return False

        if not self.in_command:
            # Waiting for the command to start.
            if heard_speech:
                self._reset_seconds_left = self.reset_seconds
                self._speech_seconds_left -= step
                if self._speech_seconds_left <= 0:
                    # Enough speech heard: now inside a voice command.
                    self.in_command = True
            else:
                # Enough silence restores the required amount of speech.
                self._reset_seconds_left -= step
                if self._reset_seconds_left <= 0:
                    self._speech_seconds_left = self.speech_seconds
            return True

        if heard_speech:
            # Inside a command: enough speech restores the required silence.
            self._reset_seconds_left -= step
            if self._reset_seconds_left <= 0:
                self._silence_seconds_left = self.silence_seconds
            return True

        # Inside a command and silent: count down towards the end of the command.
        self._reset_seconds_left = self.reset_seconds
        self._silence_seconds_left -= step
        if self._silence_seconds_left <= 0:
            return False

        return True
|
|
|
|
|
|
@dataclass
class VoiceActivityTimeout:
    """Watches streamed audio and reports when enough silence has accumulated."""

    silence_seconds: float
    """Seconds of silence before timeout."""

    reset_seconds: float = 0.5
    """Seconds of speech before the silence timeout is restored."""

    vad_mode: int = 3
    """Aggressiveness in filtering out non-speech. 3 is the most aggressive."""

    vad_samples_per_chunk: int = 480  # 30 ms
    """Samples per VAD chunk. Must be 10, 20, or 30 ms at 16Khz."""

    _silence_seconds_left: float = 0.0
    """Seconds of silence left before the timeout is reached."""

    _reset_seconds_left: float = 0.0
    """Seconds of speech left before the silence timeout is restored."""

    _vad: webrtcvad.Vad = None
    _leftover_chunk_buffer: AudioBuffer = field(init=False)
    _bytes_per_chunk: int = field(init=False)
    _seconds_per_chunk: float = field(init=False)

    def __post_init__(self) -> None:
        """Create the VAD, derive chunk sizes and put the counters in their start state."""
        samples_per_chunk = self.vad_samples_per_chunk

        self._vad = webrtcvad.Vad(self.vad_mode)
        self._bytes_per_chunk = samples_per_chunk * _SAMPLE_WIDTH
        self._seconds_per_chunk = samples_per_chunk / _SAMPLE_RATE
        # The leftover buffer holds at most one (partial) chunk.
        self._leftover_chunk_buffer = AudioBuffer(self._bytes_per_chunk)
        self.reset()

    def reset(self) -> None:
        """Drop buffered audio and restore both counters to their configured values."""
        self._leftover_chunk_buffer.clear()
        self._reset_seconds_left = self.reset_seconds
        self._silence_seconds_left = self.silence_seconds

    def process(self, samples: bytes) -> bool:
        """Process 16-bit 16Khz mono audio samples.

        Returns False when timeout is reached.
        """
        # all() stops at the first chunk that reports the timeout.
        return all(
            self._process_chunk(chunk)
            for chunk in chunk_samples(
                samples, self._bytes_per_chunk, self._leftover_chunk_buffer
            )
        )

    def _process_chunk(self, chunk: bytes) -> bool:
        """Process a single chunk of 16-bit 16Khz mono audio.

        Returns False when timeout is reached.
        """
        step = self._seconds_per_chunk

        if self._vad.is_speech(chunk, _SAMPLE_RATE):
            # Speech: once enough has been heard, restore the silence timeout.
            self._reset_seconds_left -= step
            if self._reset_seconds_left <= 0:
                self._silence_seconds_left = self.silence_seconds
            return True

        # Silence: count down towards the timeout.
        self._silence_seconds_left -= step
        if self._silence_seconds_left <= 0:
            return False

        # Slowly build the reset counter back up, capped at its configured value.
        self._reset_seconds_left = min(
            self.reset_seconds, self._reset_seconds_left + step
        )
        return True
|
|
|
|
|
|
def chunk_samples(
    samples: bytes,
    bytes_per_chunk: int,
    leftover_chunk_buffer: AudioBuffer,
) -> Iterable[bytes]:
    """Yield fixed-sized chunks from samples, keeping leftover bytes from previous call(s).

    Bytes that do not fill a whole chunk are stored in ``leftover_chunk_buffer``
    and are prepended to the samples of the next call.

    Args:
        samples: Raw audio bytes to split.
        bytes_per_chunk: Size in bytes of every yielded chunk; must be positive.
        leftover_chunk_buffer: Buffer carrying a partial chunk between calls.
            It is modified in place.

    Yields:
        Chunks of exactly ``bytes_per_chunk`` bytes.

    Raises:
        ValueError: If ``bytes_per_chunk`` is not positive (raised when
            iteration starts), or propagated from the buffer when the
            leftover bytes do not fit in it.
    """
    if bytes_per_chunk <= 0:
        # A non-positive chunk size would never advance through the samples.
        raise ValueError("bytes_per_chunk must be positive")

    if (len(leftover_chunk_buffer) + len(samples)) < bytes_per_chunk:
        # Extend leftover chunk, but not enough samples to complete it
        leftover_chunk_buffer.append(samples)
        return

    next_chunk_idx = 0

    if leftover_chunk_buffer:
        # Add to leftover chunk from previous call(s).
        bytes_to_copy = bytes_per_chunk - len(leftover_chunk_buffer)
        leftover_chunk_buffer.append(samples[:bytes_to_copy])
        next_chunk_idx = bytes_to_copy

        # Empty the buffer *before* yielding: a consumer may stop iterating
        # at this chunk, in which case code after the yield never runs and
        # the same chunk would be yielded again on the next call.
        completed_chunk = leftover_chunk_buffer.bytes()
        leftover_chunk_buffer.clear()
        yield completed_chunk

    while next_chunk_idx < len(samples) - bytes_per_chunk + 1:
        # Process full chunk
        yield samples[next_chunk_idx : next_chunk_idx + bytes_per_chunk]
        next_chunk_idx += bytes_per_chunk

    # Capture leftover bytes that do not fill a whole chunk
    if rest_samples := samples[next_chunk_idx:]:
        leftover_chunk_buffer.append(rest_samples)
|