Support setting Amazon Polly engine in service call (#120226)

pull/125151/head
Jakob Schlyter 2024-09-03 15:45:37 +02:00 committed by GitHub
parent d6bd4312ab
commit 822660732b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 19 additions and 4 deletions

View File

@ -2,6 +2,7 @@
from __future__ import annotations
from collections import defaultdict
import logging
from typing import Any, Final
@ -114,6 +115,8 @@ def get_engine(
all_voices: dict[str, dict[str, str]] = {}
all_engines: dict[str, set[str]] = defaultdict(set)
all_voices_req = polly_client.describe_voices()
for voice in all_voices_req.get("Voices", []):
@ -124,8 +127,12 @@ def get_engine(
language_code: str | None = voice.get("LanguageCode")
if language_code is not None and language_code not in supported_languages:
supported_languages.append(language_code)
for engine in voice.get("SupportedEngines"):
all_engines[engine].add(voice_id)
return AmazonPollyProvider(polly_client, config, supported_languages, all_voices)
return AmazonPollyProvider(
polly_client, config, supported_languages, all_voices, all_engines
)
class AmazonPollyProvider(Provider):
@ -137,13 +144,16 @@ class AmazonPollyProvider(Provider):
config: ConfigType,
supported_languages: list[str],
all_voices: dict[str, dict[str, str]],
all_engines: dict[str, set[str]],
) -> None:
"""Initialize Amazon Polly provider for TTS."""
self.client = polly_client
self.config = config
self.supported_langs = supported_languages
self.all_voices = all_voices
self.all_engines = all_engines
self.default_voice: str = self.config[CONF_VOICE]
self.default_engine: str = self.config[CONF_ENGINE]
self.name = "Amazon Polly"
@property
@ -159,12 +169,12 @@ class AmazonPollyProvider(Provider):
@property
def default_options(self) -> dict[str, str]:
"""Return dict include default options."""
return {CONF_VOICE: self.default_voice}
return {CONF_VOICE: self.default_voice, CONF_ENGINE: self.default_engine}
@property
def supported_options(self) -> list[str]:
"""Return a list of supported options."""
return [CONF_VOICE]
return [CONF_VOICE, CONF_ENGINE]
def get_tts_audio(
self,
@ -179,9 +189,14 @@ class AmazonPollyProvider(Provider):
_LOGGER.error("%s does not support the %s language", voice_id, language)
return None, None
engine = options.get(CONF_ENGINE, self.default_engine)
if voice_id not in self.all_engines[engine]:
_LOGGER.error("%s does not support the %s engine", voice_id, engine)
return None, None
_LOGGER.debug("Requesting TTS file for text: %s", message)
resp = self.client.synthesize_speech(
Engine=self.config[CONF_ENGINE],
Engine=engine,
OutputFormat=self.config[CONF_OUTPUT_FORMAT],
SampleRate=self.config[CONF_SAMPLE_RATE],
Text=message,