Add type hints to TTS provider (#78285)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
pull/78331/head
epenet 2022-09-12 23:29:55 +02:00 committed by GitHub
parent 9d47160e68
commit 55e59b778c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 20 deletions

View File

@ -11,7 +11,7 @@ import mimetypes
import os
from pathlib import Path
import re
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from aiohttp import web
import mutagen
@ -49,8 +49,6 @@ from homeassistant.util.yaml import load_yaml
from .const import DOMAIN
# mypy: allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
TtsAudioType = tuple[Optional[str], Optional[bytes]]
@ -86,7 +84,7 @@ _RE_VOICE_FILE = re.compile(r"([a-f0-9]{40})_([^_]+)_([^_]+)_([a-z_]+)\.[a-z0-9]
KEY_PATTERN = "{0}_{1}_{2}_{3}"
def _deprecated_platform(value):
def _deprecated_platform(value: str) -> str:
"""Validate if platform is deprecated."""
if value == "google":
raise vol.Invalid(
@ -253,7 +251,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
if setup_tasks:
await asyncio.wait(setup_tasks)
async def async_platform_discovered(platform, info):
async def async_platform_discovered(
platform: str, info: dict[str, Any] | None
) -> None:
"""Handle for discovered platform."""
await async_setup_platform(platform, discovery_info=info)
@ -327,7 +327,7 @@ class SpeechManager:
"""Read file cache and delete files."""
self.mem_cache = {}
def remove_files():
def remove_files() -> None:
"""Remove files from filesystem."""
for filename in self.file_cache.values():
try:
@ -365,7 +365,11 @@ class SpeechManager:
# Languages
language = language or provider.default_language
if language is None or language not in provider.supported_languages:
if (
language is None
or provider.supported_languages is None
or language not in provider.supported_languages
):
raise HomeAssistantError(f"Not supported language {language}")
# Options
@ -583,33 +587,33 @@ class Provider:
name: str | None = None
@property
def default_language(self):
def default_language(self) -> str | None:
"""Return the default language."""
return None
@property
def supported_languages(self):
def supported_languages(self) -> list[str] | None:
"""Return a list of supported languages."""
return None
@property
def supported_options(self):
"""Return a list of supported options like voice, emotionen."""
def supported_options(self) -> list[str] | None:
"""Return a list of supported options like voice, emotions."""
return None
@property
def default_options(self):
def default_options(self) -> dict[str, Any] | None:
"""Return a dict include default options."""
return None
def get_tts_audio(
self, message: str, language: str, options: dict | None = None
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider."""
raise NotImplementedError()
async def async_get_tts_audio(
self, message: str, language: str, options: dict | None = None
self, message: str, language: str, options: dict[str, Any] | None = None
) -> TtsAudioType:
"""Load tts audio file from provider.

View File

@ -1,20 +1,22 @@
"""Support notifications through TTS service."""
from __future__ import annotations
import logging
from typing import Any
import voluptuous as vol
from homeassistant.components.notify import PLATFORM_SCHEMA, BaseNotificationService
from homeassistant.const import ATTR_ENTITY_ID, CONF_NAME
from homeassistant.core import split_entity_id
from homeassistant.core import HomeAssistant, split_entity_id
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import ATTR_LANGUAGE, ATTR_MESSAGE, DOMAIN
CONF_MEDIA_PLAYER = "media_player"
CONF_TTS_SERVICE = "tts_service"
# mypy: allow-untyped-calls, allow-untyped-defs, no-check-untyped-defs
_LOGGER = logging.getLogger(__name__)
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
@ -27,7 +29,11 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(
)
async def async_get_service(hass, config, discovery_info=None):
async def async_get_service(
hass: HomeAssistant,
config: ConfigType,
discovery_info: DiscoveryInfoType | None = None,
) -> TTSNotificationService:
"""Return the notify service."""
return TTSNotificationService(config)
@ -36,13 +42,13 @@ async def async_get_service(hass, config, discovery_info=None):
class TTSNotificationService(BaseNotificationService):
"""The TTS Notification Service."""
def __init__(self, config):
def __init__(self, config: ConfigType) -> None:
"""Initialize the service."""
_, self._tts_service = split_entity_id(config[CONF_TTS_SERVICE])
self._media_player = config[CONF_MEDIA_PLAYER]
self._language = config.get(ATTR_LANGUAGE)
async def async_send_message(self, message="", **kwargs):
async def async_send_message(self, message: str = "", **kwargs: Any) -> None:
"""Call TTS service to speak the notification."""
_LOGGER.debug("%s '%s' on %s", self._tts_service, message, self._media_player)

View File

@ -2127,6 +2127,35 @@ _INHERITANCE_MATCH: dict[str, list[ClassTypeHintMatch]] = {
],
),
],
"tts": [
ClassTypeHintMatch(
base_class="Provider",
matches=[
TypeHintMatch(
function_name="default_language",
return_type=["str", None],
),
TypeHintMatch(
function_name="supported_languages",
return_type=["list[str]", None],
),
TypeHintMatch(
function_name="supported_options",
return_type=["list[str]", None],
),
TypeHintMatch(
function_name="default_options",
return_type=["dict[str, Any]", None],
),
TypeHintMatch(
function_name="get_tts_audio",
arg_types={1: "str", 2: "str", 3: "dict[str, Any] | None"},
return_type="TtsAudioType",
has_async_counterpart=True,
),
],
),
],
"update": [
ClassTypeHintMatch(
base_class="Entity",