diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py
index 3360a940..0f8c4760 100644
--- a/TTS/bin/train_tts.py
+++ b/TTS/bin/train_tts.py
@@ -1,4 +1,5 @@
 import os
+import torch
 
 from TTS.config import check_config_and_model_args, get_from_config_or_model_args, load_config, register_config
 from TTS.trainer import Trainer, TrainingArgs
@@ -53,15 +54,22 @@ def main():
         else:
             config.num_speakers = speaker_manager.num_speakers
     elif check_config_and_model_args(config, "use_d_vector_file", True):
-        speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
+        if check_config_and_model_args(config, "use_speaker_encoder_as_loss", True):
+            speaker_manager = SpeakerManager(
+                d_vectors_file_path=config.model_args.d_vector_file,
+                encoder_model_path=config.model_args.speaker_encoder_model_path,
+                encoder_config_path=config.model_args.speaker_encoder_config_path,
+                use_cuda=torch.cuda.is_available(),
+            )
+        else:
+            speaker_manager = SpeakerManager(d_vectors_file_path=get_from_config_or_model_args(config, "d_vector_file"))
+        config.num_speakers = speaker_manager.num_speakers
         if hasattr(config, "model_args"):
             config.model_args.num_speakers = speaker_manager.num_speakers
-        else:
-            config.num_speakers = speaker_manager.num_speakers
     else:
         speaker_manager = None
 
-    if hasattr(config, "use_language_embedding") and config.use_language_embedding:
+    if check_config_and_model_args(config, "use_language_embedding", True):
         language_manager = LanguageManager(config=config)
         if hasattr(config, "model_args"):
             config.model_args.num_languages = language_manager.num_languages
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 8b09fdf9..b2e4be9e 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
 
 import torch
 
-# import torchaudio
+import torchaudio
 from coqpit import Coqpit
 from torch import nn
 from torch.cuda.amp.autocast_mode import autocast
@@ -419,21 +419,12 @@ class Vits(BaseTTS):
                 hasattr(self.speaker_manager.speaker_encoder, "audio_config")
                 and self.config.audio["sample_rate"] != self.speaker_manager.speaker_encoder.audio_config["sample_rate"]
             ):
-                # TODO: change this with torchaudio Resample
-                raise RuntimeError(
-                    " [!] To use the speaker consistency loss (SCL) you need to have matching sample rates between the TTS model ({}) and the speaker encoder ({})!".format(
-                        self.config.audio["sample_rate"],
-                        self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
-                    )
-                )
-                # pylint: disable=W0101,W0105
-                """ self.audio_transform = torchaudio.transforms.Resample(
+                self.audio_transform = torchaudio.transforms.Resample(
                     orig_freq=self.audio_config["sample_rate"],
                     new_freq=self.speaker_manager.speaker_encoder.audio_config["sample_rate"],
                 )
-                else:
-                    self.audio_transform = None
-                """
+            else:
+                self.audio_transform = None
 
     def _init_speaker_embedding(self):
         # pylint: disable=attribute-defined-outside-init
@@ -458,6 +449,7 @@ class Vits(BaseTTS):
             self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)
 
         if self.args.use_language_embedding and self.language_manager:
+            print(" > initialization of language-embedding layers.")
            self.num_languages = self.language_manager.num_languages
             self.embedded_language_dim = self.args.embedded_language_dim
             self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
@@ -643,8 +635,8 @@ class Vits(BaseTTS):
 
             # resample audio to speaker encoder sample_rate
             # pylint: disable=W0105
-            """if self.audio_transform is not None:
-                wavs_batch = self.audio_transform(wavs_batch)"""
+            if self.audio_transform is not None:
+                wavs_batch = self.audio_transform(wavs_batch)
 
             pred_embs = self.speaker_manager.speaker_encoder.forward(wavs_batch, l2_norm=True)
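
Note (not part of the patch): a minimal sketch of the resampling path this diff enables, using assumed example rates of 22050 Hz for the TTS model and 16000 Hz for the speaker encoder. In the patch the real values come from config.audio["sample_rate"] and the speaker encoder's audio_config, the transform is built in Vits.init_multispeaker, and it is applied to wavs_batch in the training step before the speaker encoder forward pass.

# Sketch only (assumed sample rates); mirrors the torchaudio.transforms.Resample
# usage introduced by the patch, outside of the TTS codebase.
import torch
import torchaudio

tts_sample_rate = 22050      # assumed: config.audio["sample_rate"]
encoder_sample_rate = 16000  # assumed: speaker_encoder.audio_config["sample_rate"]

audio_transform = torchaudio.transforms.Resample(
    orig_freq=tts_sample_rate,
    new_freq=encoder_sample_rate,
)

# Batch of waveforms shaped (batch, samples), e.g. ground-truth and generated
# segments concatenated as in the speaker-consistency-loss path.
wavs_batch = torch.randn(4, tts_sample_rate)

# The patch only builds the transform when the two rates differ; otherwise
# audio_transform is None and the batch is passed to the encoder unchanged.
if tts_sample_rate != encoder_sample_rate:
    wavs_batch = audio_transform(wavs_batch)

print(wavs_batch.shape)  # torch.Size([4, 16000])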