"""Support for Wyoming speech-to-text services."""
from collections.abc import AsyncIterable
import logging

from wyoming.asr import Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient

from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .const import DOMAIN, SAMPLE_CHANNELS, SAMPLE_RATE, SAMPLE_WIDTH
from .data import WyomingService
from .error import WyomingError

_LOGGER = logging.getLogger(__name__)


async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register one speech-to-text entity for the given config entry.

    The Wyoming service object is looked up in ``hass.data[DOMAIN]`` under
    the config entry's id and handed to the new entity.
    """
    service: WyomingService = hass.data[DOMAIN][config_entry.entry_id]
    provider = WyomingSttProvider(config_entry, service)
    async_add_entities([provider])


class WyomingSttProvider(stt.SpeechToTextEntity):
    """Speech-to-text entity backed by a Wyoming service."""

    def __init__(
        self,
        config_entry: ConfigEntry,
        service: WyomingService,
    ) -> None:
        """Store the service and derive languages, name and unique id.

        Only the first entry of ``service.info.asr`` is consulted.
        """
        self.service = service
        asr_service = service.info.asr[0]

        # Languages are gathered from installed models only; the set removes
        # duplicates when several models report the same language.
        installed_languages: set[str] = {
            language
            for model in asr_service.models
            if model.installed
            for language in model.languages
        }

        self._supported_languages = list(installed_languages)
        self._attr_name = asr_service.name
        self._attr_unique_id = f"{config_entry.entry_id}-stt"

    @property
    def supported_languages(self) -> list[str]:
        """Return the languages collected from the installed models."""
        return self._supported_languages

    @property
    def supported_formats(self) -> list[stt.AudioFormats]:
        """Return the accepted audio formats (WAV only)."""
        return [stt.AudioFormats.WAV]

    @property
    def supported_codecs(self) -> list[stt.AudioCodecs]:
        """Return the accepted audio codecs (PCM only)."""
        return [stt.AudioCodecs.PCM]

    @property
    def supported_bit_rates(self) -> list[stt.AudioBitRates]:
        """Return the accepted bit rates (16 only)."""
        return [stt.AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[stt.AudioSampleRates]:
        """Return the accepted sample rates (16000 only)."""
        return [stt.AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[stt.AudioChannels]:
        """Return the accepted channel layouts (mono only)."""
        return [stt.AudioChannels.CHANNEL_MONO]

    async def async_process_audio_stream(
        self, metadata: stt.SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> stt.SpeechResult:
        """Send an audio stream to the Wyoming service and return its transcript.

        Opens a TCP client to the service's host and port, writes an
        ``AudioStart`` event, one ``AudioChunk`` event per item of ``stream``
        and an ``AudioStop`` event, then reads events until a ``Transcript``
        arrives.

        Returns a ``SUCCESS`` result carrying the transcript text. Returns an
        ``ERROR`` result with no text when ``read_event`` yields ``None``
        (logged as a lost connection) or when ``OSError`` / ``WyomingError``
        is raised while talking to the service.
        """
        try:
            async with AsyncTcpClient(self.service.host, self.service.port) as client:
                # Announce the audio parameters before any audio data.
                start = AudioStart(
                    rate=SAMPLE_RATE,
                    width=SAMPLE_WIDTH,
                    channels=SAMPLE_CHANNELS,
                )
                await client.write_event(start.event())

                # Forward every piece of the incoming stream as its own chunk.
                async for audio_bytes in stream:
                    await client.write_event(
                        AudioChunk(
                            rate=SAMPLE_RATE,
                            width=SAMPLE_WIDTH,
                            channels=SAMPLE_CHANNELS,
                            audio=audio_bytes,
                        ).event()
                    )

                await client.write_event(AudioStop().event())

                # Skip events of other types until the transcript shows up.
                while (event := await client.read_event()) is not None:
                    if Transcript.is_type(event.type):
                        text = Transcript.from_event(event).text
                        break
                else:
                    # The loop ended without a transcript: no event was read.
                    _LOGGER.debug("Connection lost")
                    return stt.SpeechResult(None, stt.SpeechResultState.ERROR)
        except (OSError, WyomingError) as err:
            _LOGGER.exception("Error processing audio stream: %s", err)
            return stt.SpeechResult(None, stt.SpeechResultState.ERROR)

        return stt.SpeechResult(
            text,
            stt.SpeechResultState.SUCCESS,
        )