Add type hints to TTS provider (#78285)
Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>pull/78331/head
parent
9d47160e68
commit
55e59b778c
|
@ -11,7 +11,7 @@ import mimetypes
|
|||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Optional, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from aiohttp import web
|
||||
import mutagen
|
||||
|
@ -49,8 +49,6 @@ from homeassistant.util.yaml import load_yaml
|
|||
|
||||
from .const import DOMAIN
|
||||
|
||||
# mypy: allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
TtsAudioType = tuple[Optional[str], Optional[bytes]]
|
||||
|
@ -86,7 +84,7 @@ _RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]
|
|||
KEY_PATTERN = "{0}_{1}_{2}_{3}"
|
||||
|
||||
|
||||
def _deprecated_platform(value):
|
||||
def _deprecated_platform(value: str) -> str:
|
||||
"""Validate if platform is deprecated."""
|
||||
if value == "google":
|
||||
raise vol.Invalid(
|
||||
|
@ -253,7 +251,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
|||
if setup_tasks:
|
||||
await asyncio.wait(setup_tasks)
|
||||
|
||||
async def async_platform_discovered(platform, info):
|
||||
async def async_platform_discovered(
|
||||
platform: str, info: dict[str, Any] | None
|
||||
) -> None:
|
||||
"""Handle for discovered platform."""
|
||||
await async_setup_platform(platform, discovery_info=info)
|
||||
|
||||
|
@ -327,7 +327,7 @@ class SpeechManager:
|
|||
"""Read file cache and delete files."""
|
||||
self.mem_cache = {}
|
||||
|
||||
def remove_files():
|
||||
def remove_files() -> None:
|
||||
"""Remove files from filesystem."""
|
||||
for filename in self.file_cache.values():
|
||||
try:
|
||||
|
@ -365,7 +365,11 @@ class SpeechManager:
|
|||
|
||||
# Languages
|
||||
language = language or provider.default_language
|
||||
if language is None or language not in provider.supported_languages:
|
||||
if (
|
||||
language is None
|
||||
or provider.supported_languages is None
|
||||
or language not in provider.supported_languages
|
||||
):
|
||||
raise HomeAssistantError(f"Not supported language {language}")
|
||||
|
||||
# Options
|
||||
|
@ -583,33 +587,33 @@ class Provider:
|
|||
name: str | None = None
|
||||
|
||||
@property
|
||||
def default_language(self):
|
||||
def default_language(self) -> str | None:
|
||||
"""Return the default language."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def supported_languages(self):
|
||||
def supported_languages(self) -> list[str] | None:
|
||||
"""Return a list of supported languages."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def supported_options(self):
|
||||
"""Return a list of supported options like voice, emotionen."""
|
||||
def supported_options(self) -> list[str] | None:
|
||||
"""Return a list of supported options like voice, emotions."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def default_options(self):
|
||||
def default_options(self) -> dict[str, Any] | None:
|
||||
"""Return a dict include default options."""
|
||||
return None
|
||||
|
||||
def get_tts_audio(
|
||||
self, message: str, language: str, options: dict | None = None
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider."""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def async_get_tts_audio(
|
||||
self, message: str, language: str, options: dict | None = None
|
||||
self, message: str, language: str, options: dict[str, Any] | None = None
|
||||
) -> TtsAudioType:
|
||||
"""Load tts audio file from provider.
|
||||
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
"""Support notifications through TTS service."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
|
||||
from homeassistant.core import split_entity_id
|
||||
from homeassistant.core import HomeAssistant, split_entity_id
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||
|
||||
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
|
||||
|
||||
CONF_MEDIA_PLAYER = "media_player"
|
||||
CONF_TTS_SERVICE = "tts_service"
|
||||
|
||||
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
||||
|
@ -27,7 +29,11 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
|
|||
)
|
||||
|
||||
|
||||
async def async_get_service(hass, config, discovery_info=None):
|
||||
async def async_get_service(
|
||||
hass: HomeAssistant,
|
||||
config: ConfigType,
|
||||
discovery_info: DiscoveryInfoType | None = None,
|
||||
) -> TTSNotificationService:
|
||||
"""Return the notify service."""
|
||||
|
||||
return TTSNotificationService(config)
|
||||
|
@ -36,13 +42,13 @@ async def async_get_service(hass, config, discovery_info=None):
|
|||
class TTSNotificationService(BaseNotificationService):
|
||||
"""The TTS Notification Service."""
|
||||
|
||||
def __init__(self, config):
|
||||
def __init__(self, config: ConfigType) -> None:
|
||||
"""Initialize the service."""
|
||||
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
|
||||
self._media_player = config[CONF_MEDIA_PLAYER]
|
||||
self._language = config.get(ATTR_LANGUAGE)
|
||||
|
||||
async def async_send_message(self, message="", **kwargs):
|
||||
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
|
||||
"""Call TTS service to speak the notification."""
|
||||
_LOGGER.debug("%s '%s' on %s", self._tts_service, message, self._media_player)
|
||||
|
||||
|
|
|
@ -2127,6 +2127,35 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
|
|||
],
|
||||
),
|
||||
],
|
||||
"tts": [
|
||||
ClassTypeHintMatch(
|
||||
base_class="Provider",
|
||||
matches=[
|
||||
TypeHintMatch(
|
||||
function_name="default_language",
|
||||
return_type=["str", None],
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="supported_languages",
|
||||
return_type=["list[str]", None],
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="supported_options",
|
||||
return_type=["list[str]", None],
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="default_options",
|
||||
return_type=["dict[str, Any]", None],
|
||||
),
|
||||
TypeHintMatch(
|
||||
function_name="get_tts_audio",
|
||||
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"},
|
||||
return_type="TtsAudioType",
|
||||
has_async_counterpart=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
"update": [
|
||||
ClassTypeHintMatch(
|
||||
base_class="Entity",
|
||||
|
|
Loading…
Reference in New Issue