diff --git a/TTS/tts/configs/glow_tts_config.py b/TTS/tts/configs/glow_tts_config.py index caf2f71b..97fd3577 100644 --- a/TTS/tts/configs/glow_tts_config.py +++ b/TTS/tts/configs/glow_tts_config.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import List from TTS.tts.configs.shared_configs import BaseTTSConfig @@ -167,3 +168,14 @@ class GlowTTSConfig(BaseTTSConfig): min_seq_len: int = 3 max_seq_len: int = 500 r: int = 1 # DO NOT CHANGE - TODO: make this immutable once coqpit implements it. + + # testing + test_sentences: List[str] = field( + default_factory=lambda: [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963.", + ] + ) diff --git a/TTS/tts/configs/speedy_speech_config.py b/TTS/tts/configs/speedy_speech_config.py index 7c9bb8e4..ba561c89 100644 --- a/TTS/tts/configs/speedy_speech_config.py +++ b/TTS/tts/configs/speedy_speech_config.py @@ -119,7 +119,7 @@ class SpeedySpeechConfig(BaseTTSConfig): hidden_channels=128, num_speakers=0, positional_encoding=True, - detach_duration_predictor=True + detach_duration_predictor=True, ) # multi-speaker settings diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index 50bcc451..36ed668b 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -165,9 +165,9 @@ class Encoder(nn.Module): # set duration predictor input if g is not None: g_exp = g.expand(-1, -1, x.size(-1)) - x_dp = torch.cat([torch.detach(x), g_exp], 1) + x_dp = torch.cat([x.detach(), g_exp], 1) else: - x_dp = torch.detach(x) + x_dp = x.detach() # final projection layer x_m = self.proj_m(x) * x_mask if not self.mean_only: diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 72e7d8d5..f465c638 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -427,11 +427,11 @@ class GlowTTSLoss(torch.nn.Module): return_dict = {} # flow loss - neg log likelihood pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) - log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1]) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[2]) # duration loss - MSE - # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) + loss_dur = torch.sum((o_dur_log - o_attn_dur) ** 2) / torch.sum(x_lengths) # duration loss - huber loss - loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + # loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) return_dict["loss"] = log_mle + loss_dur return_dict["log_mle"] = log_mle return_dict["loss_dur"] = loss_dur diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index e643c69f..2e94659e 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -2,6 +2,7 @@ import math import torch from torch import nn +from torch.cuda.amp.autocast_mode import autocast from torch.nn import functional as F from TTS.tts.configs import GlowTTSConfig @@ -68,6 +69,8 @@ class GlowTTS(BaseTTS): # TODO: make this adjustable self.c_in_channels = 256 + self.run_data_dep_init = config.data_dep_init_steps > 0 + self.encoder = Encoder( self.num_chars, out_channels=self.out_channels, @@ -131,6 +134,18 @@ class GlowTTS(BaseTTS): o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask return y_mean, y_log_scale, o_attn_dur + def unlock_act_norm_layers(self): + """Unlock activation normalization layers for data depended initalization.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(True) + + def lock_act_norm_layers(self): + """Lock activation normalization layers.""" + for f in self.decoder.flows: + if getattr(f, "set_ddi", False): + f.set_ddi(False) + def forward( self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None} ): # pylint: disable=dangerous-default-value @@ -142,6 +157,7 @@ class GlowTTS(BaseTTS): - y_lengths::math:`B` - g: :math:`[B, C] or B` """ + # [B, T, C] -> [B, C, T] y = y.transpose(1, 2) y_max_length = y.size(2) # norm speaker embeddings @@ -157,6 +173,7 @@ class GlowTTS(BaseTTS): y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) # create masks y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) + # [B, 1, T_en, T_de] attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # decoder pass z, logdet = self.decoder(y, y_mask, g=g, reverse=False) @@ -172,7 +189,7 @@ class GlowTTS(BaseTTS): y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) attn = attn.squeeze(1).permute(0, 2, 1) outputs = { - "model_outputs": z.transpose(1, 2), + "z": z.transpose(1, 2), "logdet": logdet, "y_mean": y_mean.transpose(1, 2), "y_log_scale": y_log_scale.transpose(1, 2), @@ -319,7 +336,8 @@ class GlowTTS(BaseTTS): return outputs def train_step(self, batch: dict, criterion: nn.Module): - """Perform a single training step by fetching the right set if samples from the batch. + """A single training step. Forward pass and loss computation. Run data depended initialization for the + first `config.data_dep_init_steps` steps. Args: batch (dict): [description] @@ -332,31 +350,57 @@ class GlowTTS(BaseTTS): d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] - outputs = self.forward( - text_input, - text_lengths, - mel_input, - mel_lengths, - aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, - ) - - loss_dict = criterion( - outputs["model_outputs"], - outputs["y_mean"], - outputs["y_log_scale"], - outputs["logdet"], - mel_lengths, - outputs["durations_log"], - outputs["total_durations_log"], - text_lengths, - ) + if self.run_data_dep_init and self.training: + # compute data-dependent initialization of activation norm layers + self.unlock_act_norm_layers() + with torch.no_grad(): + _ = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + outputs = None + loss_dict = None + self.lock_act_norm_layers() + else: + # normal training step + outputs = self.forward( + text_input, + text_lengths, + mel_input, + mel_lengths, + aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + with autocast(enabled=False): # avoid mixed_precision in criterion + loss_dict = criterion( + outputs["z"].float(), + outputs["y_mean"].float(), + outputs["y_log_scale"].float(), + outputs["logdet"].float(), + mel_lengths, + outputs["durations_log"].float(), + outputs["total_durations_log"].float(), + text_lengths, + ) return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use - model_outputs = outputs["model_outputs"] alignments = outputs["alignments"] + text_input = batch["text_input"] + text_lengths = batch["text_lengths"] mel_input = batch["mel_input"] + d_vectors = batch["d_vectors"] + speaker_ids = batch["speaker_ids"] + + # model runs reverse flow to predict spectrograms + pred_outputs = self.inference( + text_input[:1], + aux_input={"x_lengths": text_lengths[:1], "d_vectors": d_vectors, "speaker_ids": speaker_ids}, + ) + model_outputs = pred_outputs["model_outputs"] pred_spec = model_outputs[0].data.cpu().numpy() gt_spec = mel_input[0].data.cpu().numpy() @@ -393,26 +437,29 @@ class GlowTTS(BaseTTS): test_figures = {} test_sentences = self.config.test_sentences aux_inputs = self.get_aux_input() - for idx, sen in enumerate(test_sentences): - outputs = synthesis( - self, - sen, - self.config, - "cuda" in str(next(self.parameters()).device), - ap, - speaker_id=aux_inputs["speaker_id"], - d_vector=aux_inputs["d_vector"], - style_wav=aux_inputs["style_wav"], - enable_eos_bos_chars=self.config.enable_eos_bos_chars, - use_griffin_lim=True, - do_trim_silence=False, - ) + if len(test_sentences) == 0: + print(" | [!] No test sentences provided.") + else: + for idx, sen in enumerate(test_sentences): + outputs = synthesis( + self, + sen, + self.config, + "cuda" in str(next(self.parameters()).device), + ap, + speaker_id=aux_inputs["speaker_id"], + d_vector=aux_inputs["d_vector"], + style_wav=aux_inputs["style_wav"], + enable_eos_bos_chars=self.config.enable_eos_bos_chars, + use_griffin_lim=True, + do_trim_silence=False, + ) - test_audios["{}-audio".format(idx)] = outputs["wav"] - test_figures["{}-prediction".format(idx)] = plot_spectrogram( - outputs["outputs"]["model_outputs"], ap, output_fig=False - ) - test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) + test_audios["{}-audio".format(idx)] = outputs["wav"] + test_figures["{}-prediction".format(idx)] = plot_spectrogram( + outputs["outputs"]["model_outputs"], ap, output_fig=False + ) + test_figures["{}-alignment".format(idx)] = plot_alignment(outputs["alignments"], output_fig=False) return test_figures, test_audios def preprocess(self, y, y_lengths, y_max_length, attn=None): @@ -441,3 +488,7 @@ class GlowTTS(BaseTTS): from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel return GlowTTSLoss() + + def on_train_step_start(self, trainer): + """Decide on every training step wheter enable/disable data depended initialization.""" + self.run_data_dep_init = trainer.total_steps_done < self.data_dep_init_steps diff --git a/recipes/ljspeech/glow_tts/train_glowtts.py b/recipes/ljspeech/glow_tts/train_glowtts.py index fbd54e88..5d71f4ed 100644 --- a/recipes/ljspeech/glow_tts/train_glowtts.py +++ b/recipes/ljspeech/glow_tts/train_glowtts.py @@ -15,13 +15,13 @@ config = GlowTTSConfig( run_eval=True, test_delay_epochs=-1, epochs=1000, - text_cleaner="english_cleaners", - use_phonemes=False, + text_cleaner="phoneme_cleaners", + use_phonemes=True, phoneme_language="en-us", phoneme_cache_path=os.path.join(output_path, "phoneme_cache"), print_step=25, - print_eval=True, - mixed_precision=False, + print_eval=False, + mixed_precision=True, output_path=output_path, datasets=[dataset_config], ) diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py index 171f2cdc..e139562c 100644 --- a/tests/tts_tests/test_glow_tts.py +++ b/tests/tts_tests/test_glow_tts.py @@ -63,7 +63,7 @@ class GlowTTSTrainTest(unittest.TestCase): optimizer.zero_grad() outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None) loss_dict = criterion( - outputs["model_outputs"], + outputs["z"], outputs["y_mean"], outputs["y_log_scale"], outputs["logdet"],