117 lines
3.7 KiB
Python
117 lines
3.7 KiB
Python
"""Support for the Home Assistant Cloud speech-to-text service."""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncIterable
|
|
import logging
|
|
|
|
from hass_nabucasa import Cloud
|
|
from hass_nabucasa.voice import STT_LANGUAGES, VoiceError
|
|
|
|
from homeassistant.components.stt import (
|
|
AudioBitRates,
|
|
AudioChannels,
|
|
AudioCodecs,
|
|
AudioFormats,
|
|
AudioSampleRates,
|
|
SpeechMetadata,
|
|
SpeechResult,
|
|
SpeechResultState,
|
|
SpeechToTextEntity,
|
|
)
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.const import Platform
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
|
|
|
from .assist_pipeline import async_migrate_cloud_pipeline_engine
|
|
from .client import CloudClient
|
|
from .const import DATA_PLATFORMS_SETUP, DOMAIN, STT_ENTITY_UNIQUE_ID
|
|
|
|
# Module logger; VoiceError failures from the cloud STT request are reported here.
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def async_setup_entry(
    hass: HomeAssistant,
    config_entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
) -> None:
    """Register the Home Assistant Cloud speech-to-text entity for a config entry.

    First marks the STT platform as loaded by setting its event in
    ``hass.data[DATA_PLATFORMS_SETUP]``. Then adds a single CloudProviderEntity
    wrapping the Cloud instance stored under ``hass.data[DOMAIN]``.
    """
    platforms_setup = hass.data[DATA_PLATFORMS_SETUP]
    stt_ready: asyncio.Event = platforms_setup[Platform.STT]
    # Wake anything awaiting this event (presumably cloud setup code that
    # waits for each platform to load).
    stt_ready.set()

    cloud_instance: Cloud[CloudClient] = hass.data[DOMAIN]
    async_add_entities([CloudProviderEntity(cloud_instance)])
|
|
|
|
|
|
class CloudProviderEntity(SpeechToTextEntity):
    """Home Assistant Cloud speech API provider.

    Exposes the Home Assistant Cloud speech-to-text service as an STT entity.
    Incoming audio is forwarded to ``cloud.voice.process_stt``, and the
    transcription is wrapped in a ``SpeechResult``.
    """

    _attr_name = "Home Assistant Cloud"
    _attr_unique_id = STT_ENTITY_UNIQUE_ID

    def __init__(self, cloud: Cloud[CloudClient]) -> None:
        """Initialize cloud Speech to text entity.

        Args:
            cloud: Cloud connection whose ``voice`` API performs recognition.
        """
        self.cloud = cloud

    @property
    def supported_languages(self) -> list[str]:
        """Return a list of supported languages (as provided by hass_nabucasa)."""
        return STT_LANGUAGES

    @property
    def supported_formats(self) -> list[AudioFormats]:
        """Return a list of supported formats."""
        return [AudioFormats.WAV, AudioFormats.OGG]

    @property
    def supported_codecs(self) -> list[AudioCodecs]:
        """Return a list of supported codecs."""
        return [AudioCodecs.PCM, AudioCodecs.OPUS]

    @property
    def supported_bit_rates(self) -> list[AudioBitRates]:
        """Return a list of supported bitrates."""
        return [AudioBitRates.BITRATE_16]

    @property
    def supported_sample_rates(self) -> list[AudioSampleRates]:
        """Return a list of supported samplerates."""
        return [AudioSampleRates.SAMPLERATE_16000]

    @property
    def supported_channels(self) -> list[AudioChannels]:
        """Return a list of supported channels."""
        return [AudioChannels.CHANNEL_MONO]

    async def async_added_to_hass(self) -> None:
        """Run when entity is about to be added to hass."""
        # NOTE(review): presumably rewrites existing assist pipelines so their
        # cloud STT engine refers to this entity's id; confirm in
        # async_migrate_cloud_pipeline_engine.
        await async_migrate_cloud_pipeline_engine(
            self.hass, platform=Platform.STT, engine_id=self.entity_id
        )

    async def async_process_audio_stream(
        self, metadata: SpeechMetadata, stream: AsyncIterable[bytes]
    ) -> SpeechResult:
        """Process an audio stream to STT service.

        Args:
            metadata: Format, codec, sample rate and language of the audio.
            stream: Audio chunks to transcribe.

        Returns:
            A SpeechResult carrying the recognized text. Its state is SUCCESS
            when the cloud reports success and ERROR otherwise. If the request
            raises VoiceError, the error is logged and the result has no text.
        """
        # Describe the audio using the caller's metadata for every field,
        # including the sample rate. The rate was previously hard-coded to
        # 16000 and so could silently disagree with the actual stream.
        # int() renders an int-enum sample rate as a plain number.
        content_type = (
            f"audio/{metadata.format!s}; codecs=audio/{metadata.codec!s};"
            f" samplerate={int(metadata.sample_rate)}"
        )

        try:
            result = await self.cloud.voice.process_stt(
                stream=stream,
                content_type=content_type,
                language=metadata.language,
            )
        except VoiceError as err:
            _LOGGER.error("Voice error: %s", err)
            return SpeechResult(None, SpeechResultState.ERROR)

        # A completed request can still report failure via result.success.
        return SpeechResult(
            result.text,
            SpeechResultState.SUCCESS if result.success else SpeechResultState.ERROR,
        )
|