diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 2a2c0b71..fa63c46a 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -229,7 +229,9 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # load data instances - meta_data_train, meta_data_eval = load_tts_samples(c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + meta_data_train, meta_data_eval = load_tts_samples( + c.datasets, eval_split=args.eval, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) # use eval and training partitions meta_data = meta_data_train + meta_data_eval diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index 541e971b..4689dcad 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -23,7 +23,9 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) items = train_items + eval_items diff --git a/TTS/bin/find_unique_phonemes.py b/TTS/bin/find_unique_phonemes.py index 8fe48b2f..0ae74bd4 100644 --- a/TTS/bin/find_unique_phonemes.py +++ b/TTS/bin/find_unique_phonemes.py @@ -40,7 +40,9 @@ def main(): c = load_config(args.config_path) # load all datasets - train_items, eval_items = load_tts_samples(c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size) + train_items, eval_items = load_tts_samples( + c.datasets, eval_split=True, eval_split_max_size=c.eval_split_max_size, eval_split_size=c.eval_split_size + ) items = train_items + eval_items print("Num items:", len(items)) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 31813712..976b74af 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -44,7 +44,12 @@ def main(): config = register_config(config_base.model)() # load training samples - train_samples, eval_samples = load_tts_samples(config.datasets, eval_split=True, eval_split_max_size=config.eval_split_max_size, eval_split_size=config.eval_split_size) + train_samples, eval_samples = load_tts_samples( + config.datasets, + eval_split=True, + eval_split_max_size=config.eval_split_max_size, + eval_split_size=config.eval_split_size, + ) # init the model from config model = setup_model(config, train_samples + eval_samples) diff --git a/TTS/tts/datasets/__init__.py b/TTS/tts/datasets/__init__.py index dde85808..6c7c9edd 100644 --- a/TTS/tts/datasets/__init__.py +++ b/TTS/tts/datasets/__init__.py @@ -12,20 +12,20 @@ from TTS.tts.datasets.formatters import * def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): """Split a dataset into train and eval. Consider speaker distribution in multi-speaker training. - Args: -<<<<<<< HEAD - items (List[List]): - A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. + Args: + <<<<<<< HEAD + items (List[List]): + A list of samples. Each sample is a list of `[audio_path, text, speaker_id]`. - eval_split_max_size (int): - Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). + eval_split_max_size (int): + Number maximum of samples to be used for evaluation in proportion split. Defaults to None (Disabled). - eval_split_size (float): - If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. - If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). -======= - items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. ->>>>>>> Fix docstring + eval_split_size (float): + If between 0.0 and 1.0 represents the proportion of the dataset to include in the evaluation set. + If > 1, represents the absolute number of evaluation samples. Defaults to 0.01 (1%). + ======= + items (List[List]): A list of samples. Each sample is a list of `[text, audio_path, speaker_id]`. + >>>>>>> Fix docstring """ speakers = [item["speaker_name"] for item in items] is_multi_speaker = len(set(speakers)) > 1 @@ -37,7 +37,11 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): else: eval_split_size = int(len(items) * eval_split_size) - assert eval_split_size > 0, " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format(1/len(items)) + assert ( + eval_split_size > 0 + ), " [!] You do not have enough samples for the evaluation set. You can work around this setting the 'eval_split_size' parameter to a minimum of {}".format( + 1 / len(items) + ) np.random.seed(0) np.random.shuffle(items) if is_multi_speaker: @@ -56,8 +60,11 @@ def split_dataset(items, eval_split_max_size=None, eval_split_size=0.01): def load_tts_samples( - datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None, - eval_split_max_size=None, eval_split_size=0.01 + datasets: Union[List[Dict], Dict], + eval_split=True, + formatter: Callable = None, + eval_split_max_size=None, + eval_split_size=0.01, ) -> Tuple[List[List], List[List]]: """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided. If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based diff --git a/TTS/tts/datasets/formatters.py b/TTS/tts/datasets/formatters.py index 4592ccce..aacfc647 100644 --- a/TTS/tts/datasets/formatters.py +++ b/TTS/tts/datasets/formatters.py @@ -132,7 +132,7 @@ def ljspeech_test(root_path, meta_file, **kwargs): # pylint: disable=unused-arg speaker_id = 0 for idx, line in enumerate(ttf): # 2 samples per speaker to avoid eval split issues - if idx%2 == 0: + if idx % 2 == 0: speaker_id += 1 cols = line.split("|") wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index c30f043a..fea570a6 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -183,8 +183,8 @@ class GlowTTS(BaseTTS): if g is not None: if hasattr(self, "emb_g"): # use speaker embedding layer - if not g.size(): # if is a scalar - g = g.unsqueeze(0) # unsqueeze + if not g.size(): # if is a scalar + g = g.unsqueeze(0) # unsqueeze g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] else: # use d-vector diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py index 8c15103f..1ad8807f 100644 --- a/TTS/tts/models/vits.py +++ b/TTS/tts/models/vits.py @@ -14,6 +14,7 @@ from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.tts.configs.shared_configs import CharactersConfig from TTS.tts.datasets.dataset import TTSDataset, _parse_sample @@ -29,7 +30,6 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.characters import BaseCharacters, _characters, _pad, _phonemes, _punctuations from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.tts.utils.visual import plot_alignment -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.models.hifigan_generator import HifiganGenerator from TTS.vocoder.utils.generic_utils import plot_results @@ -1481,10 +1481,12 @@ class Vits(BaseTTS): language_manager = LanguageManager.init_from_config(config) if config.model_args.speaker_encoder_model_path is not None: - speaker_manager.init_speaker_encoder(config.model_args.speaker_encoder_model_path, - config.model_args.speaker_encoder_config_path) + speaker_manager.init_speaker_encoder( + config.model_args.speaker_encoder_model_path, config.model_args.speaker_encoder_config_path + ) return Vits(new_config, ap, tokenizer, speaker_manager, language_manager) + ################################## # VITS CHARACTERS ################################## diff --git a/TTS/vocoder/models/gan.py b/TTS/vocoder/models/gan.py index 91467956..3b8a3fbe 100644 --- a/TTS/vocoder/models/gan.py +++ b/TTS/vocoder/models/gan.py @@ -7,10 +7,10 @@ from coqpit import Coqpit from torch import nn from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_fsspec -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss from TTS.vocoder.models import setup_discriminator, setup_generator diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 95aa3cd2..c4968f1f 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -8,9 +8,9 @@ from torch import nn from torch.nn.utils import weight_norm from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler +from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.utils.io import load_fsspec -from trainer.trainer_utils import get_optimizer, get_scheduler from TTS.vocoder.datasets import WaveGradDataset from TTS.vocoder.layers.wavegrad import Conv1d, DBlock, FiLM, UBlock from TTS.vocoder.models.base_vocoder import BaseVocoder diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 2727bbdd..0562fbf7 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader from tests import get_tests_output_path -from TTS.tts.configs.shared_configs import BaseTTSConfig, BaseDatasetConfig +from TTS.tts.configs.shared_configs import BaseDatasetConfig, BaseTTSConfig from TTS.tts.datasets import TTSDataset, load_tts_samples from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor @@ -24,7 +24,7 @@ c.data_path = "tests/data/ljspeech/" ok_ljspeech = os.path.exists(c.data_path) dataset_config = BaseDatasetConfig( - name="ljspeech_test", # ljspeech_test to multi-speaker + name="ljspeech_test", # ljspeech_test to multi-speaker meta_file_train="metadata.csv", meta_file_val=None, path=c.data_path, @@ -106,9 +106,9 @@ class TestTTSDataset(unittest.TestCase): # make sure that the computed mels and the waveform match and correctly computed mel_new = self.ap.melspectrogram(wavs[0].squeeze().numpy()) # remove padding in mel-spectrogram - mel_dataloader = mel_input[0].T.numpy()[:, :mel_lengths[0]] + mel_dataloader = mel_input[0].T.numpy()[:, : mel_lengths[0]] # guarantee that both mel-spectrograms have the same size and that we will remove waveform padding - mel_new = mel_new[:, :mel_lengths[0]] + mel_new = mel_new[:, : mel_lengths[0]] ignore_seg = -(1 + c.audio.win_length // c.audio.hop_length) mel_diff = (mel_new[:, : mel_input.shape[1]] - mel_input[0].T.numpy())[:, 0:ignore_seg] self.assertLess(abs(mel_diff.sum()), 1e-5) diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index 97878574..d643cb81 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -1,13 +1,12 @@ import os import unittest +from tests import get_tests_output_path from TTS.config import load_config from TTS.tts.models import setup_model from TTS.utils.io import save_checkpoint from TTS.utils.synthesizer import Synthesizer -from tests import get_tests_output_path - class SynthesizerTest(unittest.TestCase): # pylint: disable=R0201 diff --git a/tests/text_tests/test_characters.py b/tests/text_tests/test_characters.py index 2ebb3bc3..8f40656a 100644 --- a/tests/text_tests/test_characters.py +++ b/tests/text_tests/test_characters.py @@ -1,13 +1,20 @@ import unittest -from TTS.tts.utils.text.characters import BaseCharacters, Graphemes, IPAPhonemes, BaseVocabulary +from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, Graphemes, IPAPhonemes # pylint: disable=protected-access + class BaseVocabularyTest(unittest.TestCase): def setUp(self): self.phonemes = IPAPhonemes() - self.base_vocab = BaseVocabulary(vocab=self.phonemes._vocab, pad=self.phonemes.pad, blank=self.phonemes.blank, bos=self.phonemes.bos, eos=self.phonemes.eos) + self.base_vocab = BaseVocabulary( + vocab=self.phonemes._vocab, + pad=self.phonemes.pad, + blank=self.phonemes.blank, + bos=self.phonemes.bos, + eos=self.phonemes.eos, + ) self.empty_vocab = BaseVocabulary({}) def test_pad_id(self): @@ -22,8 +29,8 @@ class BaseVocabularyTest(unittest.TestCase): self.assertEqual(self.empty_vocab.vocab, {}) self.assertEqual(self.base_vocab.vocab, self.phonemes._vocab) - def test_init_from_config(self): - ... + # def test_init_from_config(self): + # ... def test_num_chars(self): self.assertEqual(self.empty_vocab.num_chars, 0) diff --git a/tests/tts_tests/test_align_tts_train.py b/tests/tts_tests/test_align_tts_train.py index 6c68d8c9..85dfbbcb 100644 --- a/tests/tts_tests/test_align_tts_train.py +++ b/tests/tts_tests/test_align_tts_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.align_tts_config import AlignTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py index 88505988..37faf449 100644 --- a/tests/tts_tests/test_fast_pitch_speaker_emb_train.py +++ b/tests/tts_tests/test_fast_pitch_speaker_emb_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "fast_pitch_speaker_emb_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -69,7 +70,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_fast_pitch_train.py b/tests/tts_tests/test_fast_pitch_train.py index 5a51f0bb..d2d78af4 100644 --- a/tests/tts_tests/test_fast_pitch_train.py +++ b/tests/tts_tests/test_fast_pitch_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseAudioConfig from TTS.tts.configs.fast_pitch_config import FastPitchConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -70,7 +71,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_glow_tts_d-vectors_train.py b/tests/tts_tests/test_glow_tts_d-vectors_train.py index dd5e954e..14f9e4d2 100644 --- a/tests/tts_tests/test_glow_tts_d-vectors_train.py +++ b/tests/tts_tests/test_glow_tts_d-vectors_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -56,7 +57,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_glow_tts_speaker_emb_train.py b/tests/tts_tests/test_glow_tts_speaker_emb_train.py index df86cf05..c327332e 100644 --- a/tests/tts_tests/test_glow_tts_speaker_emb_train.py +++ b/tests/tts_tests/test_glow_tts_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -53,7 +54,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_glow_tts_train.py b/tests/tts_tests/test_glow_tts_train.py index 3a1c4a68..b0acf004 100644 --- a/tests/tts_tests/test_glow_tts_train.py +++ b/tests/tts_tests/test_glow_tts_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.glow_tts_config import GlowTTSConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -52,7 +53,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_speedy_speech_train.py b/tests/tts_tests/test_speedy_speech_train.py index 98cf8e09..9a26d253 100644 --- a/tests/tts_tests/test_speedy_speech_train.py +++ b/tests/tts_tests/test_speedy_speech_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.speedy_speech_config import SpeedySpeechConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_speedy_speech_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example for it.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_tacotron2_d-vectors_train.py b/tests/tts_tests/test_tacotron2_d-vectors_train.py index e5f83804..6b003f2c 100644 --- a/tests/tts_tests/test_tacotron2_d-vectors_train.py +++ b/tests/tts_tests/test_tacotron2_d-vectors_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -56,7 +57,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_tacotron2_speaker_emb_train.py b/tests/tts_tests/test_tacotron2_speaker_emb_train.py index 2dd50c73..b9f4de0b 100644 --- a/tests/tts_tests/test_tacotron2_speaker_emb_train.py +++ b/tests/tts_tests/test_tacotron2_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -54,7 +55,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_tacotron2_train.py b/tests/tts_tests/test_tacotron2_train.py index a45065b2..8c30d9f9 100644 --- a/tests/tts_tests/test_tacotron2_train.py +++ b/tests/tts_tests/test_tacotron2_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron2_config import Tacotron2Config -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_tacotron_train.py b/tests/tts_tests/test_tacotron_train.py index 96c63162..40cd2d3d 100644 --- a/tests/tts_tests/test_tacotron_train.py +++ b/tests/tts_tests/test_tacotron_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.tacotron_config import TacotronConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -52,7 +53,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command) diff --git a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py index c09f8498..0c7672d7 100644 --- a/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_multilingual_speaker_emb_train.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -85,7 +86,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech" languae_id = "en" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py index 8607a8f7..a8e2020e 100644 --- a/tests/tts_tests/test_vits_multilingual_train-d_vectors.py +++ b/tests/tts_tests/test_vits_multilingual_train-d_vectors.py @@ -2,10 +2,11 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.config.shared_configs import BaseDatasetConfig from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -89,7 +90,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" languae_id = "en" continue_speakers_path = config.d_vector_file diff --git a/tests/tts_tests/test_vits_speaker_emb_train.py b/tests/tts_tests/test_vits_speaker_emb_train.py index 8a586076..c928cee4 100644 --- a/tests/tts_tests/test_vits_speaker_emb_train.py +++ b/tests/tts_tests/test_vits_speaker_emb_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -60,7 +61,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") speaker_id = "ljspeech-1" continue_speakers_path = os.path.join(continue_path, "speakers.json") diff --git a/tests/tts_tests/test_vits_train.py b/tests/tts_tests/test_vits_train.py index 76c88682..003f99a8 100644 --- a/tests/tts_tests/test_vits_train.py +++ b/tests/tts_tests/test_vits_train.py @@ -2,9 +2,10 @@ import glob import os import shutil +from trainer import get_last_checkpoint + from tests import get_device_id, get_tests_output_path, run_cli from TTS.tts.configs.vits_config import VitsConfig -from trainer import get_last_checkpoint config_path = os.path.join(get_tests_output_path(), "test_model_config.json") output_path = os.path.join(get_tests_output_path(), "train_outputs") @@ -51,7 +52,7 @@ continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getm # Inference using TTS API continue_config_path = os.path.join(continue_path, "config.json") continue_restore_path, _ = get_last_checkpoint(continue_path) -out_wav_path = os.path.join(get_tests_output_path(), 'output.wav') +out_wav_path = os.path.join(get_tests_output_path(), "output.wav") inference_command = f"CUDA_VISIBLE_DEVICES='{get_device_id()}' tts --text 'This is an example.' --config_path {continue_config_path} --model_path {continue_restore_path} --out_path {out_wav_path}" run_cli(inference_command)