"""Handle legacy speech-to-text platforms.""" from __future__ import annotations from abc import ABC, abstractmethod from collections.abc import AsyncIterable, Coroutine import logging from typing import Any from homeassistant.config import config_per_platform from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import discovery from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.setup import async_prepare_setup_platform from .const import ( DATA_PROVIDERS, DOMAIN, AudioBitRates, AudioChannels, AudioCodecs, AudioFormats, AudioSampleRates, ) from .models import SpeechMetadata, SpeechResult _LOGGER = logging.getLogger(__name__) @callback def async_default_provider(hass: HomeAssistant) -> str | None: """Return the domain of the default provider.""" return next(iter(hass.data[DATA_PROVIDERS]), None) @callback def async_get_provider( hass: HomeAssistant, domain: str | None = None ) -> Provider | None: """Return provider.""" providers: dict[str, Provider] = hass.data[DATA_PROVIDERS] if domain: return providers.get(domain) provider = async_default_provider(hass) return providers[provider] if provider is not None else None @callback def async_setup_legacy( hass: HomeAssistant, config: ConfigType ) -> list[Coroutine[Any, Any, None]]: """Set up legacy speech-to-text providers.""" providers = hass.data[DATA_PROVIDERS] = {} async def async_setup_platform( p_type: str, p_config: ConfigType | None = None, discovery_info: DiscoveryInfoType | None = None, ) -> None: """Set up an STT platform.""" if p_config is None: p_config = {} platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) if platform is None: _LOGGER.error("Unknown speech-to-text platform specified") return try: provider = await platform.async_get_engine(hass, p_config, discovery_info) provider.name = p_type provider.hass = hass providers[provider.name] = provider except Exception: # pylint: disable=broad-except _LOGGER.exception("Error setting up platform: %s", p_type) return # Add discovery support async def async_platform_discovered( platform: str, info: DiscoveryInfoType | None ) -> None: """Handle for discovered platform.""" await async_setup_platform(platform, discovery_info=info) discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) return [ async_setup_platform(p_type, p_config) for p_type, p_config in config_per_platform(config, DOMAIN) if p_type ] class Provider(ABC): """Represent a single STT provider.""" hass: HomeAssistant | None = None name: str | None = None @property @abstractmethod def supported_languages(self) -> list[str]: """Return a list of supported languages.""" @property @abstractmethod def supported_formats(self) -> list[AudioFormats]: """Return a list of supported formats.""" @property @abstractmethod def supported_codecs(self) -> list[AudioCodecs]: """Return a list of supported codecs.""" @property @abstractmethod def supported_bit_rates(self) -> list[AudioBitRates]: """Return a list of supported bit rates.""" @property @abstractmethod def supported_sample_rates(self) -> list[AudioSampleRates]: """Return a list of supported sample rates.""" @property @abstractmethod def supported_channels(self) -> list[AudioChannels]: """Return a list of supported channels.""" @abstractmethod async def async_process_audio_stream( self, metadata: SpeechMetadata, stream: AsyncIterable[bytes] ) -> SpeechResult: """Process an audio stream to STT service. Only streaming of content are allow! """ @callback def check_metadata(self, metadata: SpeechMetadata) -> bool: """Check if given metadata supported by this provider.""" if ( metadata.language not in self.supported_languages or metadata.format not in self.supported_formats or metadata.codec not in self.supported_codecs or metadata.bit_rate not in self.supported_bit_rates or metadata.sample_rate not in self.supported_sample_rates or metadata.channel not in self.supported_channels ): return False return True