Merge pull request #197 from idiap/api

Expand Python API capabilities
pull/4115/head^2
Enno Hermann 2024-12-06 18:02:54 +01:00 committed by GitHub
commit b545ab8b80
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 164 additions and 180 deletions

View File

@ -1,12 +1,14 @@
"""Coqui TTS Python API."""
import logging
import tempfile
import warnings
from pathlib import Path
from typing import Optional
from torch import nn
from TTS.config import load_config
from TTS.utils.audio.numpy_transforms import save_wav
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
@ -19,13 +21,19 @@ class TTS(nn.Module):
def __init__(
self,
model_name: str = "",
model_path: str = None,
config_path: str = None,
vocoder_path: str = None,
vocoder_config_path: str = None,
*,
model_path: Optional[str] = None,
config_path: Optional[str] = None,
vocoder_name: Optional[str] = None,
vocoder_path: Optional[str] = None,
vocoder_config_path: Optional[str] = None,
encoder_path: Optional[str] = None,
encoder_config_path: Optional[str] = None,
speakers_file_path: Optional[str] = None,
language_ids_file_path: Optional[str] = None,
progress_bar: bool = True,
gpu=False,
):
gpu: bool = False,
) -> None:
"""🐸TTS python interface that allows to load and use the released models.
Example with a multi-speaker model:
@ -35,31 +43,36 @@ class TTS(nn.Module):
>>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")
Example with a single-speaker model:
>>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False, gpu=False)
>>> tts = TTS(model_name="tts_models/de/thorsten/tacotron2-DDC", progress_bar=False)
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
Example loading a model from a path:
>>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False, gpu=False)
>>> tts = TTS(model_path="/path/to/checkpoint_100000.pth", config_path="/path/to/config.json", progress_bar=False)
>>> tts.tts_to_file(text="Ich bin eine Testnachricht.", file_path="output.wav")
Example voice cloning with YourTTS in English, French and Portuguese:
>>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=True)
>>> tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False).to("cuda")
>>> tts.tts_to_file("This is voice cloning.", speaker_wav="my/cloning/audio.wav", language="en", file_path="thisisit.wav")
>>> tts.tts_to_file("C'est le clonage de la voix.", speaker_wav="my/cloning/audio.wav", language="fr", file_path="thisisit.wav")
>>> tts.tts_to_file("Isso é clonagem de voz.", speaker_wav="my/cloning/audio.wav", language="pt", file_path="thisisit.wav")
Example Fairseq TTS models (uses ISO language codes in https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html):
>>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False, gpu=True)
>>> tts = TTS(model_name="tts_models/eng/fairseq/vits", progress_bar=False).to("cuda")
>>> tts.tts_to_file("This is a test.", file_path="output.wav")
Args:
model_name (str, optional): Model name to load. You can list models by ```tts.models```. Defaults to "".
model_path (str, optional): Path to the model checkpoint. Defaults to None.
config_path (str, optional): Path to the model config. Defaults to None.
vocoder_name (str, optional): Pre-trained vocoder to use. Defaults to None, i.e. using the default vocoder.
vocoder_path (str, optional): Path to the vocoder checkpoint. Defaults to None.
vocoder_config_path (str, optional): Path to the vocoder config. Defaults to None.
progress_bar (bool, optional): Whether to pring a progress bar while downloading a model. Defaults to True.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
encoder_path: Path to speaker encoder checkpoint. Defaults to None.
encoder_config_path: Path to speaker encoder config file. Defaults to None.
speakers_file_path: JSON file for multi-speaker model. Defaults to None.
language_ids_file_path: JSON file for multilingual model. Defaults to None.
progress_bar (bool, optional): Whether to print a progress bar while downloading a model. Defaults to True.
gpu (bool, optional): Enable/disable GPU. Defaults to False. DEPRECATED, use TTS(...).to("cuda")
"""
super().__init__()
self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar)
@ -67,34 +80,45 @@ class TTS(nn.Module):
self.synthesizer = None
self.voice_converter = None
self.model_name = ""
self.vocoder_path = vocoder_path
self.vocoder_config_path = vocoder_config_path
self.encoder_path = encoder_path
self.encoder_config_path = encoder_config_path
self.speakers_file_path = speakers_file_path
self.language_ids_file_path = language_ids_file_path
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
if model_name is not None and len(model_name) > 0:
if "tts_models" in model_name:
self.load_tts_model_by_name(model_name, gpu)
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
elif "voice_conversion_models" in model_name:
self.load_vc_model_by_name(model_name, gpu)
self.load_vc_model_by_name(model_name, gpu=gpu)
# To allow just TTS("xtts")
else:
self.load_model_by_name(model_name, gpu)
self.load_model_by_name(model_name, vocoder_name, gpu=gpu)
if model_path:
self.load_tts_model_by_path(
model_path, config_path, vocoder_path=vocoder_path, vocoder_config=vocoder_config_path, gpu=gpu
)
self.load_tts_model_by_path(model_path, config_path, gpu=gpu)
@property
def models(self):
def models(self) -> list[str]:
return self.manager.list_tts_models()
@property
def is_multi_speaker(self):
if hasattr(self.synthesizer.tts_model, "speaker_manager") and self.synthesizer.tts_model.speaker_manager:
def is_multi_speaker(self) -> bool:
if (
self.synthesizer is not None
and hasattr(self.synthesizer.tts_model, "speaker_manager")
and self.synthesizer.tts_model.speaker_manager
):
return self.synthesizer.tts_model.speaker_manager.num_speakers > 1
return False
@property
def is_multi_lingual(self):
def is_multi_lingual(self) -> bool:
# Not sure what sets this to None, but applied a fix to prevent crashing.
if (
isinstance(self.model_name, str)
@ -103,51 +127,63 @@ class TTS(nn.Module):
and ("xtts" in self.config.model or "languages" in self.config and len(self.config.languages) > 1)
):
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
if (
self.synthesizer is not None
and hasattr(self.synthesizer.tts_model, "language_manager")
and self.synthesizer.tts_model.language_manager
):
return self.synthesizer.tts_model.language_manager.num_languages > 1
return False
@property
def speakers(self):
def speakers(self) -> list[str]:
if not self.is_multi_speaker:
return None
return self.synthesizer.tts_model.speaker_manager.speaker_names
@property
def languages(self):
def languages(self) -> list[str]:
if not self.is_multi_lingual:
return None
return self.synthesizer.tts_model.language_manager.language_names
@staticmethod
def get_models_file_path():
def get_models_file_path() -> Path:
return Path(__file__).parent / ".models.json"
@staticmethod
def list_models():
def list_models() -> list[str]:
return ModelManager(models_file=TTS.get_models_file_path(), progress_bar=False).list_models()
def download_model_by_name(self, model_name: str):
def download_model_by_name(
self, model_name: str, vocoder_name: Optional[str] = None
) -> tuple[Optional[str], Optional[str], Optional[str]]:
model_path, config_path, model_item = self.manager.download_model(model_name)
if "fairseq" in model_name or (model_item is not None and isinstance(model_item["model_url"], list)):
# return model directory if there are multiple files
# we assume that the model knows how to load itself
return None, None, None, None, model_path
return None, None, model_path
if model_item.get("default_vocoder") is None:
return model_path, config_path, None, None, None
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path, None
return model_path, config_path, None
if vocoder_name is None:
vocoder_name = model_item["default_vocoder"]
vocoder_path, vocoder_config_path, _ = self.manager.download_model(vocoder_name)
# A local vocoder model will take precedence if specified via vocoder_path
if self.vocoder_path is None or self.vocoder_config_path is None:
self.vocoder_path = vocoder_path
self.vocoder_config_path = vocoder_config_path
return model_path, config_path, None
def load_model_by_name(self, model_name: str, gpu: bool = False):
def load_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
"""Load one of the 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.load_tts_model_by_name(model_name, gpu)
self.load_tts_model_by_name(model_name, vocoder_name, gpu=gpu)
def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
def load_vc_model_by_name(self, model_name: str, *, gpu: bool = False) -> None:
"""Load one of the voice conversion models by name.
Args:
@ -155,12 +191,12 @@ class TTS(nn.Module):
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.model_name = model_name
model_path, config_path, _, _, model_dir = self.download_model_by_name(model_name)
model_path, config_path, model_dir = self.download_model_by_name(model_name)
self.voice_converter = Synthesizer(
vc_checkpoint=model_path, vc_config=config_path, model_dir=model_dir, use_cuda=gpu
)
def load_tts_model_by_name(self, model_name: str, gpu: bool = False):
def load_tts_model_by_name(self, model_name: str, vocoder_name: Optional[str] = None, *, gpu: bool = False) -> None:
"""Load one of 🐸TTS models by name.
Args:
@ -172,7 +208,7 @@ class TTS(nn.Module):
self.synthesizer = None
self.model_name = model_name
model_path, config_path, vocoder_path, vocoder_config_path, model_dir = self.download_model_by_name(model_name)
model_path, config_path, model_dir = self.download_model_by_name(model_name, vocoder_name)
# init synthesizer
# None values are fetched from the model
@ -181,17 +217,15 @@ class TTS(nn.Module):
tts_config_path=config_path,
tts_speakers_file=None,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint=None,
encoder_config=None,
vocoder_checkpoint=self.vocoder_path,
vocoder_config=self.vocoder_config_path,
encoder_checkpoint=self.encoder_path,
encoder_config=self.encoder_config_path,
model_dir=model_dir,
use_cuda=gpu,
)
def load_tts_model_by_path(
self, model_path: str, config_path: str, vocoder_path: str = None, vocoder_config: str = None, gpu: bool = False
):
def load_tts_model_by_path(self, model_path: str, config_path: str, *, gpu: bool = False) -> None:
"""Load a model from a path.
Args:
@ -205,22 +239,22 @@ class TTS(nn.Module):
self.synthesizer = Synthesizer(
tts_checkpoint=model_path,
tts_config_path=config_path,
tts_speakers_file=None,
tts_languages_file=None,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config,
encoder_checkpoint=None,
encoder_config=None,
tts_speakers_file=self.speakers_file_path,
tts_languages_file=self.language_ids_file_path,
vocoder_checkpoint=self.vocoder_path,
vocoder_config=self.vocoder_config_path,
encoder_checkpoint=self.encoder_path,
encoder_config=self.encoder_config_path,
use_cuda=gpu,
)
def _check_arguments(
self,
speaker: str = None,
language: str = None,
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
speaker: Optional[str] = None,
language: Optional[str] = None,
speaker_wav: Optional[str] = None,
emotion: Optional[str] = None,
speed: Optional[float] = None,
**kwargs,
) -> None:
"""Check if the arguments are valid for the model."""
@ -280,10 +314,6 @@ class TTS(nn.Module):
speaker_name=speaker,
language_name=language,
speaker_wav=speaker_wav,
reference_wav=None,
style_wav=None,
style_text=None,
reference_speaker_name=None,
split_sentences=split_sentences,
**kwargs,
)
@ -301,7 +331,7 @@ class TTS(nn.Module):
file_path: str = "output.wav",
split_sentences: bool = True,
**kwargs,
):
) -> str:
"""Convert text to speech.
Args:
@ -367,6 +397,7 @@ class TTS(nn.Module):
source_wav: str,
target_wav: str,
file_path: str = "output.wav",
pipe_out=None,
) -> str:
"""Voice conversion with FreeVC. Convert source wav to target speaker.
@ -377,9 +408,11 @@ class TTS(nn.Module):
Path to the target wav file.
file_path (str, optional):
Output file path. Defaults to "output.wav".
pipe_out (BytesIO, optional):
Output stream (e.g. stdout) to write the generated wav file to for shell piping.
"""
wav = self.voice_conversion(source_wav=source_wav, target_wav=target_wav)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
self.voice_converter.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
return file_path
def tts_with_vc(
@ -432,7 +465,8 @@ class TTS(nn.Module):
file_path: str = "output.wav",
speaker: str = None,
split_sentences: bool = True,
):
pipe_out=None,
) -> str:
"""Convert text to speech with voice conversion and save to file.
Check `tts_with_vc` for more details.
@ -455,8 +489,11 @@ class TTS(nn.Module):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it to False uses more VRAM and may hit model-specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
pipe_out (BytesIO, optional):
Output stream (e.g. stdout) to write the generated wav file to for shell piping.
"""
wav = self.tts_with_vc(
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
self.voice_converter.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
return file_path

View File

@ -9,8 +9,6 @@ import sys
from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path
from TTS.utils.generic_utils import ConsoleFormatter, setup_logger
logger = logging.getLogger(__name__)
@ -253,11 +251,6 @@ def parse_args() -> argparse.Namespace:
action="store_true",
)
# aux args
parser.add_argument(
"--save_spectogram",
action="store_true",
help="Save raw spectogram for further (vocoder) processing in out_path.",
)
parser.add_argument(
"--reference_wav",
type=str,
@ -317,7 +310,8 @@ def parse_args() -> argparse.Namespace:
return args
def main():
def main() -> None:
"""Entry point for `tts` command line interface."""
setup_logger("TTS", level=logging.INFO, screen=True, formatter=ConsoleFormatter())
args = parse_args()
@ -325,12 +319,11 @@ def main():
with contextlib.redirect_stdout(None if args.pipe_out else sys.stdout):
# Late-import to make things load faster
from TTS.api import TTS
from TTS.utils.manage import ModelManager
from TTS.utils.synthesizer import Synthesizer
# load model manager
path = Path(__file__).parent / "../.models.json"
manager = ModelManager(path, progress_bar=args.progress_bar)
manager = ModelManager(models_file=TTS.get_models_file_path(), progress_bar=args.progress_bar)
tts_path = None
tts_config_path = None
@ -344,12 +337,12 @@ def main():
vc_config_path = None
model_dir = None
# CASE1 #list : list pre-trained TTS models
# 1) List pre-trained TTS models
if args.list_models:
manager.list_models()
sys.exit()
# CASE2 #info : model info for pre-trained TTS models
# 2) Info about pre-trained TTS models (without loading a model)
if args.model_info_by_idx:
model_query = args.model_info_by_idx
manager.model_info_by_idx(model_query)
@ -360,91 +353,50 @@ def main():
manager.model_info_by_full_name(model_query_full_name)
sys.exit()
# CASE3: load pre-trained model paths
if args.model_name is not None and not args.model_path:
model_path, config_path, model_item = manager.download_model(args.model_name)
# tts model
if model_item["model_type"] == "tts_models":
tts_path = model_path
tts_config_path = config_path
if args.vocoder_name is None and "default_vocoder" in model_item:
args.vocoder_name = model_item["default_vocoder"]
# voice conversion model
if model_item["model_type"] == "voice_conversion_models":
vc_path = model_path
vc_config_path = config_path
# tts model with multiple files to be loaded from the directory path
if model_item.get("author", None) == "fairseq" or isinstance(model_item["model_url"], list):
model_dir = model_path
tts_path = None
tts_config_path = None
args.vocoder_name = None
# load vocoder
if args.vocoder_name is not None and not args.vocoder_path:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
# CASE4: set custom model paths
if args.model_path is not None:
tts_path = args.model_path
tts_config_path = args.config_path
speakers_file_path = args.speakers_file_path
language_ids_file_path = args.language_ids_file_path
if args.vocoder_path is not None:
vocoder_path = args.vocoder_path
vocoder_config_path = args.vocoder_config_path
if args.encoder_path is not None:
encoder_path = args.encoder_path
encoder_config_path = args.encoder_config_path
# 3) Load a model for further info or TTS/VC
device = args.device
if args.use_cuda:
device = "cuda"
# load models
synthesizer = Synthesizer(
tts_checkpoint=tts_path,
tts_config_path=tts_config_path,
tts_speakers_file=speakers_file_path,
tts_languages_file=language_ids_file_path,
vocoder_checkpoint=vocoder_path,
vocoder_config=vocoder_config_path,
encoder_checkpoint=encoder_path,
encoder_config=encoder_config_path,
vc_checkpoint=vc_path,
vc_config=vc_config_path,
model_dir=model_dir,
voice_dir=args.voice_dir,
# A local model will take precedence if specified via model_path
model_name = args.model_name if args.model_path is None else None
api = TTS(
model_name=model_name,
model_path=args.model_path,
config_path=args.config_path,
vocoder_name=args.vocoder_name,
vocoder_path=args.vocoder_path,
vocoder_config_path=args.vocoder_config_path,
encoder_path=args.encoder_path,
encoder_config_path=args.encoder_config_path,
speakers_file_path=args.speakers_file_path,
language_ids_file_path=args.language_ids_file_path,
progress_bar=args.progress_bar,
).to(device)
# query speaker ids of a multi-speaker model.
if args.list_speaker_idxs:
if synthesizer.tts_model.speaker_manager is None:
if not api.is_multi_speaker:
logger.info("Model only has a single speaker.")
return
logger.info(
"Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
)
logger.info(list(synthesizer.tts_model.speaker_manager.name_to_id.keys()))
logger.info(api.speakers)
return
# query language ids of a multi-lingual model.
if args.list_language_idxs:
if synthesizer.tts_model.language_manager is None:
if not api.is_multi_lingual:
logger.info("Monolingual model.")
return
logger.info(
"Available language ids: (Set --language_idx flag to one of these values to use the multi-lingual model."
)
logger.info(synthesizer.tts_model.language_manager.name_to_id)
logger.info(api.languages)
return
# check the arguments against a multi-speaker model.
if synthesizer.tts_speakers_file and (not args.speaker_idx and not args.speaker_wav):
if api.is_multi_speaker and (not args.speaker_idx and not args.speaker_wav):
logger.error(
"Looks like you use a multi-speaker model. Define `--speaker_idx` to "
"select the target speaker. You can list the available speakers for this model by `--list_speaker_idxs`."
@ -455,31 +407,29 @@ def main():
if args.text:
logger.info("Text: %s", args.text)
# kick it
if tts_path is not None:
wav = synthesizer.tts(
args.text,
speaker_name=args.speaker_idx,
language_name=args.language_idx,
if args.text is not None:
api.tts_to_file(
text=args.text,
speaker=args.speaker_idx,
language=args.language_idx,
speaker_wav=args.speaker_wav,
pipe_out=pipe_out,
file_path=args.out_path,
reference_wav=args.reference_wav,
style_wav=args.capacitron_style_wav,
style_text=args.capacitron_style_text,
reference_speaker_name=args.reference_speaker_idx,
voice_dir=args.voice_dir,
)
elif vc_path is not None:
wav = synthesizer.voice_conversion(
logger.info("Saved TTS output to %s", args.out_path)
elif args.source_wav is not None and args.target_wav is not None:
api.voice_conversion_to_file(
source_wav=args.source_wav,
target_wav=args.target_wav,
file_path=args.out_path,
pipe_out=pipe_out,
)
elif model_dir is not None:
wav = synthesizer.tts(
args.text, speaker_name=args.speaker_idx, language_name=args.language_idx, speaker_wav=args.speaker_wav
)
# save the results
synthesizer.save_wav(wav, args.out_path, pipe_out=pipe_out)
logger.info("Saved output to %s", args.out_path)
logger.info("Saved VC output to %s", args.out_path)
if __name__ == "__main__":

View File

@ -37,7 +37,7 @@ from TTS.api import TTS
# Load the model to GPU
# Bark is really slow on CPU, so we recommend using GPU.
tts = TTS("tts_models/multilingual/multi-dataset/bark", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/bark").to("cuda")
# Cloning a new speaker
@ -57,7 +57,7 @@ tts.tts_to_file(text="Hello, my name is Manmay , how are you?",
# random speaker
tts = TTS("tts_models/multilingual/multi-dataset/bark", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/bark").to("cuda")
tts.tts_to_file("hello world", file_path="out.wav")
```

View File

@ -118,7 +118,7 @@ You can optionally disable sentence splitting for better coherence but more VRAM
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda")
# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
@ -137,15 +137,15 @@ You can pass multiple audio files to the `speaker_wav` argument for better voice
from TTS.api import TTS
# using the default version set in 🐸TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda")
# using a specific version
# 👀 see the branch names for versions on https://huggingface.co/coqui/XTTS-v2/tree/main
# ❗some versions might be incompatible with the API
tts = TTS("xtts_v2.0.2", gpu=True)
tts = TTS("xtts_v2.0.2").to("cuda")
# getting the latest XTTS_v2
tts = TTS("xtts", gpu=True)
tts = TTS("xtts").to("cuda")
# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
@ -160,7 +160,7 @@ You can do inference using one of the available speakers using the following cod
```python
from TTS.api import TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=True)
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to("cuda")
# generate speech by cloning a voice using default settings
tts.tts_to_file(text="It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",

View File

@ -34,30 +34,27 @@ def run_models(offset=0, step=1):
# download and run the model
speaker_files = glob.glob(local_download_dir + "/speaker*")
language_files = glob.glob(local_download_dir + "/language*")
language_id = ""
speaker_arg = ""
language_arg = ""
if len(speaker_files) > 0:
# multi-speaker model
if "speaker_ids" in speaker_files[0]:
speaker_manager = SpeakerManager(speaker_id_file_path=speaker_files[0])
elif "speakers" in speaker_files[0]:
speaker_manager = SpeakerManager(d_vectors_file_path=speaker_files[0])
# multi-lingual model - Assuming multi-lingual models are also multi-speaker
if len(language_files) > 0 and "language_ids" in language_files[0]:
language_manager = LanguageManager(language_ids_file_path=language_files[0])
language_id = language_manager.language_names[0]
speaker_id = list(speaker_manager.name_to_id.keys())[0]
run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --speaker_idx "{speaker_id}" --language_idx "{language_id}" --no-progress_bar'
)
else:
# single-speaker model
run_cli(
f"tts --model_name {model_name} "
f'--text "This is an example." --out_path "{output_path}" --no-progress_bar'
)
speakers = list(speaker_manager.name_to_id.keys())
if len(speakers) > 1:
speaker_arg = f'--speaker_idx "{speakers[0]}"'
if len(language_files) > 0 and "language_ids" in language_files[0]:
# multi-lingual model
language_manager = LanguageManager(language_ids_file_path=language_files[0])
languages = language_manager.language_names
if len(languages) > 1:
language_arg = f'--language_idx "{languages[0]}"'
run_cli(
f'tts --model_name {model_name} --text "This is an example." '
f'--out_path "{output_path}" {speaker_arg} {language_arg} --no-progress_bar'
)
# remove downloaded models
shutil.rmtree(local_download_dir)
shutil.rmtree(get_user_data_dir("tts"))