core/homeassistant/components/stt/legacy.py

175 lines
5.0 KiB
Python
Raw Normal View History

"""Handle legacy speech-to-text platforms."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Coroutine
from dataclasses import dataclass
import logging
from typing import Any
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import config_per_platform, discovery
from homeassistant.helpers.typing import ConfigType
from homeassistant.setup import async_prepare_setup_platform
from .const import (
DATA_PROVIDERS,
DOMAIN,
AudioBitRates,
AudioChannels,
AudioCodecs,
AudioFormats,
AudioSampleRates,
SpeechResultState,
)
# Module-level logger; async_setup_legacy reports platform setup failures through it.
_LOGGER = logging.getLogger(__name__)
@callback
def async_default_provider(hass: HomeAssistant) -> str | None:
    """Pick the domain of the provider to use when none is requested.

    The "cloud" provider wins whenever it is registered; otherwise the first
    registered provider (in dict insertion order) is chosen. ``None`` is
    returned when no provider has been registered.
    """
    registered = hass.data[DATA_PROVIDERS]
    if "cloud" not in registered:
        return next(iter(registered), None)
    return "cloud"
@callback
def async_get_provider(
    hass: HomeAssistant, domain: str | None = None
) -> Provider | None:
    """Look up a registered speech-to-text provider.

    With a truthy ``domain``, the provider registered under exactly that
    domain is returned, or ``None`` if there is none. Otherwise the provider
    picked by ``async_default_provider`` is returned, or ``None`` when no
    provider is registered at all.
    """
    registered = hass.data[DATA_PROVIDERS]
    if not domain:
        default_domain = async_default_provider(hass)
        if default_domain is None:
            return None
        return registered[default_domain]
    return registered.get(domain)
@callback
def async_setup_legacy(
    hass: HomeAssistant, config: ConfigType
) -> list[Coroutine[Any, Any, None]]:
    """Set up legacy speech-to-text providers.

    Resets ``hass.data[DATA_PROVIDERS]`` to a fresh empty dict, registers a
    discovery listener for STT platforms, and returns one (not yet awaited)
    setup coroutine per platform configured under the STT domain.
    """
    providers = hass.data[DATA_PROVIDERS] = {}

    async def async_setup_platform(
        p_type: str,
        p_config: ConfigType | None = None,
        discovery_info: dict[str, Any] | None = None,
    ) -> None:
        """Set up an STT platform and register its provider under ``p_type``."""
        if p_config is None:
            p_config = {}

        platform = await async_prepare_setup_platform(hass, config, DOMAIN, p_type)
        if platform is None:
            # Name the platform so the offending configuration can be located.
            _LOGGER.error("Unknown speech-to-text platform specified: %s", p_type)
            return

        try:
            provider = await platform.async_get_engine(hass, p_config, discovery_info)
            if provider is None:
                # Guard against the engine factory returning None: report it
                # plainly instead of via an AttributeError traceback from the
                # attribute assignments below.
                _LOGGER.error("Error setting up platform: %s", p_type)
                return

            provider.name = p_type
            provider.hass = hass

            providers[provider.name] = provider
        except Exception:  # pylint: disable=broad-except
            # Best effort: log and swallow so this setup coroutine completes
            # normally even when the platform misbehaves.
            _LOGGER.exception("Error setting up platform: %s", p_type)

    # Add discovery support
    async def async_platform_discovered(
        platform: str, info: dict[str, Any] | None
    ) -> None:
        """Set up a platform announced through discovery."""
        await async_setup_platform(platform, discovery_info=info)

    discovery.async_listen_platform(hass, DOMAIN, async_platform_discovered)

    return [
        async_setup_platform(p_type, p_config)
        for p_type, p_config in config_per_platform(config, DOMAIN)
    ]
@dataclass
class SpeechMetadata:
    """Describe the audio stream handed to a speech-to-text provider.

    After initialisation, ``bit_rate``, ``sample_rate`` and ``channel`` are
    each passed through ``int()`` and then through their matching ``Audio*``
    type constructor. ``language``, ``format`` and ``codec`` are stored as
    given.
    """

    language: str
    format: AudioFormats
    codec: AudioCodecs
    bit_rate: AudioBitRates
    sample_rate: AudioSampleRates
    channel: AudioChannels

    def __post_init__(self) -> None:
        """Normalise the numeric fields, in declaration order."""
        for field_name, field_type in (
            ("bit_rate", AudioBitRates),
            ("sample_rate", AudioSampleRates),
            ("channel", AudioChannels),
        ):
            raw_value = getattr(self, field_name)
            setattr(self, field_name, field_type(int(raw_value)))
@dataclass
class SpeechResult:
    """Outcome produced by a provider's ``async_process_audio_stream``.

    Pairs the transcribed text (which may be ``None``) with the
    ``SpeechResultState`` reported for the run.
    """

    text: str | None  # Transcription; may be None.
    result: SpeechResultState  # State reported for this run.
class Provider(ABC):
    """Abstract base for a single legacy speech-to-text provider.

    Subclasses advertise the audio parameters they accept through the
    ``supported_*`` properties and implement ``async_process_audio_stream``.
    ``hass`` and ``name`` default to ``None`` at class level; the legacy
    setup in this module assigns both once the platform engine is created.
    """

    hass: HomeAssistant | None = None
    name: str | None = None

    @property
    @abstractmethod
    def supported_languages(self) -> list[str]:
        """Languages this provider accepts."""

    @property
    @abstractmethod
    def supported_formats(self) -> list[AudioFormats]:
        """Audio container formats this provider accepts."""

    @property
    @abstractmethod
    def supported_codecs(self) -> list[AudioCodecs]:
        """Audio codecs this provider accepts."""

    @property
    @abstractmethod
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Bit rates this provider accepts."""

    @property
    @abstractmethod
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Sample rates this provider accepts."""

    @property
    @abstractmethod
    def supported_channels(self) -> list[AudioChannels]:
        """Channel layouts this provider accepts."""

    @abstractmethod
    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Send an audio stream, described by ``metadata``, to the STT service.

        The audio arrives as an async iterable of byte chunks; only streamed
        content is supported.
        """

    @callback
    def check_metadata(self, metadata: SpeechMetadata) -> bool:
        """Tell whether this provider supports every attribute of ``metadata``.

        Language, format, codec, bit rate, sample rate and channel are tested
        in that order against the matching ``supported_*`` property; checking
        stops at the first unsupported value.
        """
        return (
            metadata.language in self.supported_languages
            and metadata.format in self.supported_formats
            and metadata.codec in self.supported_codecs
            and metadata.bit_rate in self.supported_bit_rates
            and metadata.sample_rate in self.supported_sample_rates
            and metadata.channel in self.supported_channels
        )