Support setting Amazon Polly engine in service call (#120226)
parent
d6bd4312ab
commit
822660732b
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
|
||||
|
@ -114,6 +115,8 @@ def get_engine(
|
|||
|
||||
all_voices: dict[str, dict[str, str]] = {}
|
||||
|
||||
all_engines: dict[str, set[str]] = defaultdict(set)
|
||||
|
||||
all_voices_req = polly_client.describe_voices()
|
||||
|
||||
for voice in all_voices_req.get("Voices", []):
|
||||
|
@ -124,8 +127,12 @@ def get_engine(
|
|||
language_code: str | None = voice.get("LanguageCode")
|
||||
if language_code is not None and language_code not in supported_languages:
|
||||
supported_languages.append(language_code)
|
||||
for engine in voice.get("SupportedEngines"):
|
||||
all_engines[engine].add(voice_id)
|
||||
|
||||
return AmazonPollyProvider(polly_client, config, supported_languages, all_voices)
|
||||
return AmazonPollyProvider(
|
||||
polly_client, config, supported_languages, all_voices, all_engines
|
||||
)
|
||||
|
||||
|
||||
class AmazonPollyProvider(Provider):
|
||||
|
@ -137,13 +144,16 @@ class AmazonPollyProvider(Provider):
|
|||
config: ConfigType,
|
||||
supported_languages: list[str],
|
||||
all_voices: dict[str, dict[str, str]],
|
||||
all_engines: dict[str, set[str]],
|
||||
) -> None:
|
||||
"""Initialize Amazon Polly provider for TTS."""
|
||||
self.client = polly_client
|
||||
self.config = config
|
||||
self.supported_langs = supported_languages
|
||||
self.all_voices = all_voices
|
||||
self.all_engines = all_engines
|
||||
self.default_voice: str = self.config[CONF_VOICE]
|
||||
self.default_engine: str = self.config[CONF_ENGINE]
|
||||
self.name = "Amazon Polly"
|
||||
|
||||
@property
|
||||
|
@ -159,12 +169,12 @@ class AmazonPollyProvider(Provider):
|
|||
@property
|
||||
def default_options(self) -> dict[str, str]:
|
||||
"""Return dict include default options."""
|
||||
return {CONF_VOICE: self.default_voice}
|
||||
return {CONF_VOICE: self.default_voice, CONF_ENGINE: self.default_engine}
|
||||
|
||||
@property
|
||||
def supported_options(self) -> list[str]:
|
||||
"""Return a list of supported options."""
|
||||
return [CONF_VOICE]
|
||||
return [CONF_VOICE, CONF_ENGINE]
|
||||
|
||||
def get_tts_audio(
|
||||
self,
|
||||
|
@ -179,9 +189,14 @@ class AmazonPollyProvider(Provider):
|
|||
_LOGGER.error("%s does not support the %s language", voice_id, language)
|
||||
return None, None
|
||||
|
||||
engine = options.get(CONF_ENGINE, self.default_engine)
|
||||
if voice_id not in self.all_engines[engine]:
|
||||
_LOGGER.error("%s does not support the %s engine", voice_id, engine)
|
||||
return None, None
|
||||
|
||||
_LOGGER.debug("Requesting TTS file for text: %s", message)
|
||||
resp = self.client.synthesize_speech(
|
||||
Engine=self.config[CONF_ENGINE],
|
||||
Engine=engine,
|
||||
OutputFormat=self.config[CONF_OUTPUT_FORMAT],
|
||||
SampleRate=self.config[CONF_SAMPLE_RATE],
|
||||
Text=message,
|
||||
|
|
Loading…
Reference in New Issue