"""Provide functionality to STT.""" from abc import ABC, abstractmethod import asyncio import logging from typing import Dict, List, Optional from aiohttp import StreamReader, web from aiohttp.hdrs import istr from aiohttp.web_exceptions import ( HTTPBadRequest, HTTPNotFound, HTTPUnsupportedMediaType, ) import attr from homeassistant.components.http import HomeAssistantView from homeassistant.core import callback from homeassistant.helpers import config_per_platform, discovery from homeassistant.helpers.typing import HomeAssistantType from homeassistant.setup import async_prepare_setup_platform from .const import ( DOMAIN, AudioBitRates, AudioChannels, AudioCodecs, AudioFormats, AudioSampleRates, SpeechResultState, ) # mypy: allow-untyped-defs, no-check-untyped-defs _LOGGER = logging.getLogger(__name__) async def async_setup(hass: HomeAssistantType, config): """Set up STT.""" providers = {} async def async_setup_platform(p_type, p_config=None, discovery_info=None): """Set up a TTS platform.""" if p_config is None: p_config = {} platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type) if platform is None: return try: provider = await platform.async_get_engine(hass, p_config, discovery_info) if provider is None: _LOGGER.error("Error setting up platform %s", p_type) return provider.name = p_type provider.hass = hass providers[provider.name] = provider except Exception: # pylint: disable=broad-except _LOGGER.exception("Error setting up platform: %s", p_type) return setup_tasks = [ async_setup_platform(p_type, p_config) for p_type, p_config in config_per_platform(config, DOMAIN) ] if setup_tasks: await asyncio.wait(setup_tasks) # Add discovery support async def async_platform_discovered(platform, info): """Handle for discovered platform.""" await async_setup_platform(platform, discovery_info=info) discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered) hass.http.register_view(SpeechToTextView(providers)) return True @attr.s class SpeechMetadata: """Metadata of audio stream.""" language: str = attr.ib() format: AudioFormats = attr.ib() codec: AudioCodecs = attr.ib() bit_rate: AudioBitRates = attr.ib(converter=int) sample_rate: AudioSampleRates = attr.ib(converter=int) channel: AudioChannels = attr.ib(converter=int) @attr.s class SpeechResult: """Result of audio Speech.""" text: Optional[str] = attr.ib() result: SpeechResultState = attr.ib() class Provider(ABC): """Represent a single STT provider.""" hass: Optional[HomeAssistantType] = None name: Optional[str] = None @property @abstractmethod def supported_languages(self) -> List[str]: """Return a list of supported languages.""" @property @abstractmethod def supported_formats(self) -> List[AudioFormats]: """Return a list of supported formats.""" @property @abstractmethod def supported_codecs(self) -> List[AudioCodecs]: """Return a list of supported codecs.""" @property @abstractmethod def supported_bit_rates(self) -> List[AudioBitRates]: """Return a list of supported bit rates.""" @property @abstractmethod def supported_sample_rates(self) -> List[AudioSampleRates]: """Return a list of supported sample rates.""" @property @abstractmethod def supported_channels(self) -> List[AudioChannels]: """Return a list of supported channels.""" @abstractmethod async def async_process_audio_stream( self, metadata: SpeechMetadata, stream: StreamReader ) -> SpeechResult: """Process an audio stream to STT service. Only streaming of content are allow! """ @callback def check_metadata(self, metadata: SpeechMetadata) -> bool: """Check if given metadata supported by this provider.""" if ( metadata.language not in self.supported_languages or metadata.format not in self.supported_formats or metadata.codec not in self.supported_codecs or metadata.bit_rate not in self.supported_bit_rates or metadata.sample_rate not in self.supported_sample_rates or metadata.channel not in self.supported_channels ): return False return True class SpeechToTextView(HomeAssistantView): """STT view to generate a text from audio stream.""" requires_auth = True url = "/api/stt/{provider}" name = "api:stt:provider" def __init__(self, providers: Dict[str, Provider]) -> None: """Initialize a tts view.""" self.providers = providers @staticmethod def _metadata_from_header(request: web.Request) -> Optional[SpeechMetadata]: """Extract metadata from header. X-Speech-Content: format=wav; codec=pcm; sample_rate=16000; bit_rate=16; channel=1; language=de_de """ try: data = request.headers[istr("X-Speech-Content")].split(";") except KeyError: _LOGGER.warning("Missing X-Speech-Content") return None # Convert Header data args = {} for value in data: value = value.strip() args[value.partition("=")[0]] = value.partition("=")[2] try: return SpeechMetadata(**args) except TypeError as err: _LOGGER.warning("Wrong format of X-Speech-Content: %s", err) return None async def post(self, request: web.Request, provider: str) -> web.Response: """Convert Speech (audio) to text.""" if provider not in self.providers: raise HTTPNotFound() stt_provider: Provider = self.providers[provider] # Get metadata metadata = self._metadata_from_header(request) if not metadata: raise HTTPBadRequest() # Check format if not stt_provider.check_metadata(metadata): raise HTTPUnsupportedMediaType() # Process audio stream result = await stt_provider.async_process_audio_stream( metadata, request.content ) # Return result return self.json(attr.asdict(result)) async def get(self, request: web.Request, provider: str) -> web.Response: """Return provider specific audio information.""" if provider not in self.providers: raise HTTPNotFound() stt_provider: Provider = self.providers[provider] return self.json( { "languages": stt_provider.supported_languages, "formats": stt_provider.supported_formats, "codecs": stt_provider.supported_codecs, "sample_rates": stt_provider.supported_sample_rates, "bit_rates": stt_provider.supported_bit_rates, "channels": stt_provider.supported_channels, } )