From 535a458f40ad8890a3b07d83fad116689a7c5553 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 25 May 2021 14:36:06 +0200 Subject: [PATCH] update Tacotron models for the trainer --- TTS/tts/configs/tacotron_config.py | 1 + TTS/tts/models/tacotron.py | 198 ++++++++++++++---- TTS/tts/models/tacotron2.py | 308 +++++++++++++++++----------- TTS/tts/models/tacotron_abstract.py | 26 ++- 4 files changed, 373 insertions(+), 160 deletions(-) diff --git a/TTS/tts/configs/tacotron_config.py b/TTS/tts/configs/tacotron_config.py index 2fc7cc78..90decaa3 100644 --- a/TTS/tts/configs/tacotron_config.py +++ b/TTS/tts/configs/tacotron_config.py @@ -126,6 +126,7 @@ class TacotronConfig(BaseTTSConfig): use_gst: bool = False gst: GSTConfig = None gst_style_input: str = None + # model specific params r: int = 2 gradual_training: List[List[int]] = None diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index c1d95a25..23bd839f 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -113,7 +113,8 @@ class Tacotron(TacotronAbstract): if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embedding_dim = 256 - self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding = nn.Embedding(self.num_speakers, + speaker_embedding_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) # speaker and gst embeddings is concat in decoder input @@ -144,7 +145,8 @@ class Tacotron(TacotronAbstract): separate_stopnet, ) self.postnet = PostCBHG(decoder_output_dim) - self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, + postnet_output_dim) # setup prenet dropout self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference @@ -181,93 +183,203 @@ class Tacotron(TacotronAbstract): separate_stopnet, ) - def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): + def forward(self, + text, + text_lengths, + mel_specs=None, + mel_lengths=None, + cond_input=None): """ Shapes: - characters: [B, T_in] + text: [B, T_in] text_lengths: [B] mel_specs: [B, T_out, C] mel_lengths: [B] - speaker_ids: [B, 1] - speaker_embeddings: [B, C] + cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] """ + outputs = { + 'alignments_backward': None, + 'decoder_outputs_backward': None + } input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim - inputs = self.embedding(characters) + inputs = self.embedding(text) # B x T_in x encoder_in_features encoder_outputs = self.encoder(inputs) # sequence masking - encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as( + encoder_outputs) # global style token if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, + cond_input['x_vectors']) # speaker embedding if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim - speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, + None] else: # B x 1 x speaker_embed_dim - speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) - encoder_outputs = self._concat_speaker_embedding(encoder_outputs, 
speaker_embeddings) + speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) + encoder_outputs = self._concat_speaker_embedding( + encoder_outputs, speaker_embeddings) # decoder_outputs: B x decoder_in_features x T_out # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in - decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder( + encoder_outputs, mel_specs, input_mask) # sequence masking if output_mask is not None: - decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + decoder_outputs = decoder_outputs * output_mask.unsqueeze( + 1).expand_as(decoder_outputs) # B x T_out x decoder_in_features postnet_outputs = self.postnet(decoder_outputs) # sequence masking if output_mask is not None: - postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs) + postnet_outputs = postnet_outputs * output_mask.unsqueeze( + 2).expand_as(postnet_outputs) # B x T_out x posnet_dim postnet_outputs = self.last_linear(postnet_outputs) # B x T_out x decoder_in_features decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() if self.bidirectional_decoder: - decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return ( - decoder_outputs, - postnet_outputs, - alignments, - stop_tokens, - decoder_outputs_backward, - alignments_backward, - ) + decoder_outputs_backward, alignments_backward = self._backward_pass( + mel_specs, encoder_outputs, input_mask) + outputs['alignments_backward'] = alignments_backward + outputs['decoder_outputs_backward'] = decoder_outputs_backward if self.double_decoder_consistency: decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( - mel_specs, encoder_outputs, alignments, input_mask - ) - return ( - decoder_outputs, - postnet_outputs, - alignments, - stop_tokens, - decoder_outputs_backward, - alignments_backward, - ) - return decoder_outputs, postnet_outputs, alignments, stop_tokens + mel_specs, encoder_outputs, alignments, input_mask) + outputs['alignments_backward'] = alignments_backward + outputs['decoder_outputs_backward'] = decoder_outputs_backward + outputs.update({ + 'postnet_outputs': postnet_outputs, + 'decoder_outputs': decoder_outputs, + 'alignments': alignments, + 'stop_tokens': stop_tokens + }) + return outputs @torch.no_grad() - def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None): - inputs = self.embedding(characters) + def inference(self, + text_input, + cond_input=None): + inputs = self.embedding(text_input) encoder_outputs = self.encoder(inputs) if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) + encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'], + cond_input['x_vectors']) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim - speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, None] else: # B x 1 x speaker_embed_dim - speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) - encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) + encoder_outputs = 
self._concat_speaker_embedding(
+            encoder_outputs, speaker_embeddings)
+        decoder_outputs, alignments, stop_tokens = self.decoder.inference(
+            encoder_outputs)
         postnet_outputs = self.postnet(decoder_outputs)
         postnet_outputs = self.last_linear(postnet_outputs)
         decoder_outputs = decoder_outputs.transpose(1, 2)
-        return decoder_outputs, postnet_outputs, alignments, stop_tokens
+        outputs = {
+            'postnet_outputs': postnet_outputs,
+            'decoder_outputs': decoder_outputs,
+            'alignments': alignments,
+            'stop_tokens': stop_tokens
+        }
+        return outputs
+
+    def train_step(self, batch, criterion):
+        """Perform a single training step by fetching the right set of samples from the batch.
+
+        Args:
+            batch ([type]): [description]
+            criterion ([type]): [description]
+        """
+        text_input = batch['text_input']
+        text_lengths = batch['text_lengths']
+        mel_input = batch['mel_input']
+        mel_lengths = batch['mel_lengths']
+        linear_input = batch['linear_input']
+        stop_targets = batch['stop_targets']
+        speaker_ids = batch['speaker_ids']
+        x_vectors = batch['x_vectors']
+
+        # forward pass model
+        outputs = self.forward(text_input,
+                               text_lengths,
+                               mel_input,
+                               mel_lengths,
+                               cond_input={
+                                   'speaker_ids': speaker_ids,
+                                   'x_vectors': x_vectors
+                               })
+
+        # set the [alignment] lengths wrt reduction factor for guided attention
+        if mel_lengths.max() % self.decoder.r != 0:
+            alignment_lengths = (
+                mel_lengths +
+                (self.decoder.r -
+                 (mel_lengths.max() % self.decoder.r))) // self.decoder.r
+        else:
+            alignment_lengths = mel_lengths // self.decoder.r
+
+        # compute loss
+        loss_dict = criterion(
+            outputs['postnet_outputs'],
+            outputs['decoder_outputs'],
+            mel_input,
+            linear_input,
+            outputs['stop_tokens'],
+            stop_targets,
+            mel_lengths,
+            outputs['decoder_outputs_backward'],
+            outputs['alignments'],
+            alignment_lengths,
+            outputs['alignments_backward'],
+            text_lengths,
+        )
+
+        # compute alignment error (the lower, the better)
+        align_error = 1 - alignment_diagonal_score(outputs['alignments'])
+        loss_dict["align_error"] = align_error
+        return outputs, loss_dict
+
+    def train_log(self, ap, batch, outputs):
+        postnet_outputs = outputs['postnet_outputs']
+        alignments = outputs['alignments']
+        alignments_backward = outputs['alignments_backward']
+        mel_input = batch['mel_input']
+
+        pred_spec = postnet_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        if self.bidirectional_decoder or self.double_decoder_consistency:
+            figures["alignment_backward"] = plot_alignment(
+                alignments_backward[0].data.cpu().numpy(), output_fig=False)
+
+        # Sample audio
+        train_audio = ap.inv_spectrogram(pred_spec.T)
+        return figures, train_audio
+
+    def eval_step(self, batch, criterion):
+        return self.train_step(batch, criterion)
+
+    def eval_log(self, ap, batch, outputs):
+        return self.train_log(ap, batch, outputs)
\ No newline at end of file
diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py
index 525eb8b3..51b181e4 100644
--- a/TTS/tts/models/tacotron2.py
+++ b/TTS/tts/models/tacotron2.py
@@ -1,12 +1,15 @@
+# coding: utf-8
+import numpy as np
 import torch
 from torch import nn
 
+from 
TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.tts.layers.tacotron.gst_layers import GST from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet from TTS.tts.models.tacotron_abstract import TacotronAbstract -# TODO: match function arguments with tacotron class Tacotron2(TacotronAbstract): """Tacotron2 as in https://arxiv.org/abs/1712.05884 @@ -43,69 +46,52 @@ class Tacotron2(TacotronAbstract): speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None. use_gst (bool, optional): enable/disable Global style token module. gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None. + gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used. + Defaults to `[]`. """ - - def __init__( - self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type="original", - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - prenet_dropout_at_inference=False, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - speaker_embedding_dim=None, - use_gst=False, - gst=None, - ): - super().__init__( - num_chars, - num_speakers, - r, - postnet_output_dim, - decoder_output_dim, - attn_type, - attn_win, - attn_norm, - prenet_type, - prenet_dropout, - prenet_dropout_at_inference, - forward_attn, - trans_agent, - forward_attn_mask, - location_attn, - attn_K, - separate_stopnet, - bidirectional_decoder, - double_decoder_consistency, - ddc_r, - encoder_in_features, - decoder_in_features, - speaker_embedding_dim, - use_gst, - gst, - ) + def __init__(self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + prenet_type="original", + prenet_dropout=True, + prenet_dropout_at_inference=False, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + use_gst=False, + gst=None, + gradual_training=[]): + super().__init__(num_chars, num_speakers, r, postnet_output_dim, + decoder_output_dim, attn_type, attn_win, attn_norm, + prenet_type, prenet_dropout, + prenet_dropout_at_inference, forward_attn, + trans_agent, forward_attn_mask, location_attn, attn_K, + separate_stopnet, bidirectional_decoder, + double_decoder_consistency, ddc_r, + encoder_in_features, decoder_in_features, + speaker_embedding_dim, use_gst, gst, gradual_training) # speaker embedding layer if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embedding_dim = 512 - self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim) + self.speaker_embedding = nn.Embedding(self.num_speakers, + speaker_embedding_dim) self.speaker_embedding.weight.data.normal_(0, 0.3) # speaker and gst embeddings is concat in decoder input @@ -176,16 +162,24 @@ class Tacotron2(TacotronAbstract): mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2) return mel_outputs, mel_outputs_postnet, alignments - def 
forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): + def forward(self, + text, + text_lengths, + mel_specs=None, + mel_lengths=None, + cond_input=None): """ Shapes: text: [B, T_in] text_lengths: [B] mel_specs: [B, T_out, C] mel_lengths: [B] - speaker_ids: [B, 1] - speaker_embeddings: [B, C] + cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C] """ + outputs = { + 'alignments_backward': None, + 'decoder_outputs_backward': None + } # compute mask for padding # B x T_in_max (boolean) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) @@ -195,94 +189,176 @@ class Tacotron2(TacotronAbstract): encoder_outputs = self.encoder(embedded_inputs, text_lengths) if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings) + encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, + cond_input['x_vectors']) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim - speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] + speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, + None] else: # B x 1 x speaker_embed_dim - speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) - encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) + speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1) + encoder_outputs = self._concat_speaker_embedding( + encoder_outputs, speaker_embeddings) - encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) + encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as( + encoder_outputs) # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r - decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder( + encoder_outputs, mel_specs, input_mask) # sequence masking if mel_lengths is not None: - decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) + decoder_outputs = decoder_outputs * output_mask.unsqueeze( + 1).expand_as(decoder_outputs) # B x mel_dim x T_out postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = decoder_outputs + postnet_outputs # sequence masking if output_mask is not None: - postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) + postnet_outputs = postnet_outputs * output_mask.unsqueeze( + 1).expand_as(postnet_outputs) # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in - decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) + decoder_outputs, postnet_outputs, alignments = self.shape_outputs( + decoder_outputs, postnet_outputs, alignments) if self.bidirectional_decoder: - decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return ( - decoder_outputs, - postnet_outputs, - alignments, - stop_tokens, - decoder_outputs_backward, - alignments_backward, - ) + decoder_outputs_backward, alignments_backward = self._backward_pass( + mel_specs, encoder_outputs, input_mask) + outputs['alignments_backward'] = alignments_backward + outputs['decoder_outputs_backward'] = decoder_outputs_backward if self.double_decoder_consistency: decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( - mel_specs, encoder_outputs, alignments, input_mask - ) - return 
( - decoder_outputs, - postnet_outputs, - alignments, - stop_tokens, - decoder_outputs_backward, - alignments_backward, - ) - return decoder_outputs, postnet_outputs, alignments, stop_tokens + mel_specs, encoder_outputs, alignments, input_mask) + outputs['alignments_backward'] = alignments_backward + outputs['decoder_outputs_backward'] = decoder_outputs_backward + outputs.update({ + 'postnet_outputs': postnet_outputs, + 'decoder_outputs': decoder_outputs, + 'alignments': alignments, + 'stop_tokens': stop_tokens + }) + return outputs @torch.no_grad() - def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): + def inference(self, text, cond_input=None): embedded_inputs = self.embedding(text).transpose(1, 2) encoder_outputs = self.encoder.inference(embedded_inputs) if self.gst and self.use_gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings) + encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'], + cond_input['x_vectors']) if self.num_speakers > 1: if not self.embeddings_per_sample: - speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] - speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2) - encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) + x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None] + x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2) + else: + x_vector = cond_input - decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) + encoder_outputs = self._concat_speaker_embedding( + encoder_outputs, x_vector) + + decoder_outputs, alignments, stop_tokens = self.decoder.inference( + encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = decoder_outputs + postnet_outputs - decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) - return decoder_outputs, postnet_outputs, alignments, stop_tokens + decoder_outputs, postnet_outputs, alignments = self.shape_outputs( + decoder_outputs, postnet_outputs, alignments) + outputs = { + 'postnet_outputs': postnet_outputs, + 'decoder_outputs': decoder_outputs, + 'alignments': alignments, + 'stop_tokens': stop_tokens + } + return outputs - def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): + def train_step(self, batch, criterion): + """Perform a single training step by fetching the right set if samples from the batch. 
+
+        Args:
+            batch ([type]): [description]
+            criterion ([type]): [description]
         """
-        Preserve model states for continuous inference
-        """
-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
+        text_input = batch['text_input']
+        text_lengths = batch['text_lengths']
+        mel_input = batch['mel_input']
+        mel_lengths = batch['mel_lengths']
+        linear_input = batch['linear_input']
+        stop_targets = batch['stop_targets']
+        speaker_ids = batch['speaker_ids']
+        x_vectors = batch['x_vectors']
 
-        if self.gst:
-            # B x gst_dim
-            encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
+        # forward pass model
+        outputs = self.forward(text_input,
+                               text_lengths,
+                               mel_input,
+                               mel_lengths,
+                               cond_input={
+                                   'speaker_ids': speaker_ids,
+                                   'x_vectors': x_vectors
+                               })
 
-        if self.num_speakers > 1:
-            if not self.embeddings_per_sample:
-                speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
-                speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2)
-            encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
+        # set the [alignment] lengths wrt reduction factor for guided attention
+        if mel_lengths.max() % self.decoder.r != 0:
+            alignment_lengths = (
+                mel_lengths +
+                (self.decoder.r -
+                 (mel_lengths.max() % self.decoder.r))) // self.decoder.r
+        else:
+            alignment_lengths = mel_lengths // self.decoder.r
 
-        mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)
-        mel_outputs_postnet = self.postnet(mel_outputs)
-        mel_outputs_postnet = mel_outputs + mel_outputs_postnet
-        mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments)
-        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
+        # compute loss
+        loss_dict = criterion(
+            outputs['postnet_outputs'],
+            outputs['decoder_outputs'],
+            mel_input,
+            linear_input,
+            outputs['stop_tokens'],
+            stop_targets,
+            mel_lengths,
+            outputs['decoder_outputs_backward'],
+            outputs['alignments'],
+            alignment_lengths,
+            outputs['alignments_backward'],
+            text_lengths,
+        )
+
+        # compute alignment error (the lower, the better)
+        align_error = 1 - alignment_diagonal_score(outputs['alignments'])
+        loss_dict["align_error"] = align_error
+        return outputs, loss_dict
+
+    def train_log(self, ap, batch, outputs):
+        postnet_outputs = outputs['postnet_outputs']
+        alignments = outputs['alignments']
+        alignments_backward = outputs['alignments_backward']
+        mel_input = batch['mel_input']
+
+        pred_spec = postnet_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        if self.bidirectional_decoder or self.double_decoder_consistency:
+            figures["alignment_backward"] = plot_alignment(
+                alignments_backward[0].data.cpu().numpy(), output_fig=False)
+
+        # Sample audio
+        train_audio = ap.inv_melspectrogram(pred_spec.T)
+        return figures, train_audio
+
+    def eval_step(self, batch, criterion):
+        return self.train_step(batch, criterion)
+
+    def eval_log(self, ap, batch, outputs):
+        return self.train_log(ap, batch, outputs)
diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py
index e684ce7c..2bea06a9 100644
--- a/TTS/tts/models/tacotron_abstract.py
+++ b/TTS/tts/models/tacotron_abstract.py
@@ -1,10 +1,12 @@
 import copy
+import logging
 from abc import ABC, abstractmethod
 
 import torch
 from torch import nn
 
-from TTS.tts.utils.generic_utils import sequence_mask
+from TTS.tts.utils.data import sequence_mask
+from TTS.utils.training import gradual_training_scheduler
 
 
 class TacotronAbstract(ABC, nn.Module):
@@ -35,6 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
         speaker_embedding_dim=None,
         use_gst=False,
         gst=None,
+        gradual_training=[]
     ):
         """Abstract Tacotron class"""
         super().__init__()
@@ -63,6 +66,7 @@ class TacotronAbstract(ABC, nn.Module):
         self.encoder_in_features = encoder_in_features
         self.decoder_in_features = decoder_in_features
         self.speaker_embedding_dim = speaker_embedding_dim
+        self.gradual_training = gradual_training
 
         # layers
         self.embedding = None
@@ -216,3 +220,23 @@ class TacotronAbstract(ABC, nn.Module):
             speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
             outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs
+
+    #############################
+    # CALLBACKS
+    #############################
+
+    def on_epoch_start(self, trainer):
+        """Callback for setting values wrt gradual training schedule.
+
+        Args:
+            trainer (TrainerTTS): TTS trainer object that is used to train this model.
+        """
+        if self.gradual_training:
+            r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
+            trainer.config.r = r
+            self.decoder.set_r(r)
+            if trainer.config.bidirectional_decoder:
+                trainer.model.decoder_backward.set_r(r)
+            trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
+            trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
+            logging.info(f"\n > Number of output frames: {self.decoder.r}")
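-- 
Editor's notes (not part of the patch, appended for context):

The new on_epoch_start() callback above delegates to gradual_training_scheduler from TTS.utils.training. Below is a minimal sketch of an equivalent scheduler, assuming the schedule is a list of [start_step, r, batch_size] entries (the layout typically used for gradual_training in TacotronConfig); the real helper may differ in detail.

# Editorial sketch, not the actual TTS.utils.training.gradual_training_scheduler.
# Assumes config.gradual_training looks like [[0, 6, 64], [10000, 4, 32], [50000, 2, 32]],
# i.e. a list of [start_step, r, batch_size] entries ordered by start step.
def gradual_training_scheduler_sketch(global_step, config):
    """Return (r, batch_size) of the last schedule entry whose start step has been reached."""
    new_values = config.gradual_training[0][1:]
    for start_step, r, batch_size in config.gradual_training:
        if global_step >= start_step:
            new_values = [r, batch_size]
    return new_values  # unpacked by the callback as: r, batch_size = ...

The callback then pushes the new r into the decoder(s) via set_r() and rebuilds the dataloaders, because changing the reduction factor changes how many spectrogram frames each decoder step produces.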
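Both Tacotron models now expose the same train_step(batch, criterion) -> (outputs, loss_dict) interface instead of handing bare tensor tuples back to the caller. A minimal sketch of a trainer-side loop driving it follows; model, criterion, optimizer and the dataloader are assumed to be built elsewhere, the batch dict must carry the keys read inside train_step() (text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, x_vectors; the last two may be None), and the "loss" key is an assumption about the criterion's return dict.

# Editorial sketch of a trainer loop for the new dict-based API; not the Coqui Trainer itself.
def train_one_epoch(model, criterion, optimizer, loader):
    model.train()
    for batch in loader:
        optimizer.zero_grad()
        # train_step() runs the forward pass, the loss and the alignment diagnostics
        outputs, loss_dict = model.train_step(batch, criterion)
        loss_dict["loss"].backward()  # assumes the criterion returns an aggregate "loss" entry
        optimizer.step()
    return outputs, loss_dict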
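The alignment-length block inside train_step() exists because the decoder emits r frames per step, so the attention alignments only span T_out // r decoder steps; when the longest spectrogram in the batch is not a multiple of r, the decoder pads it up to the next multiple, and every length has to be rescaled the same way before it reaches the guided-attention terms of the loss. A worked example with illustrative numbers:

# Editorial example of the reduction-factor arithmetic used in train_step().
import torch

r = 3                                      # decoder reduction factor (frames per decoder step)
mel_lengths = torch.tensor([100, 97, 93])  # frames per sample; max() == 100

if mel_lengths.max() % r != 0:
    # pad every length by the amount needed to make the longest one divisible by r
    alignment_lengths = (mel_lengths + (r - mel_lengths.max() % r)) // r
else:
    alignment_lengths = mel_lengths // r

print(alignment_lengths)                   # tensor([34, 33, 31])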
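inference() now takes the same kind of cond_input dict; its keys mirror what the method reads (speaker_ids, x_vectors, style_mel) and unused entries can be left as None. A hedged usage sketch, with the text-to-id conversion assumed to happen elsewhere:

# Editorial sketch: synthesis with the new cond_input interface.
import torch

def synthesize(model, text_ids, speaker_id=None):
    # text_ids: a precomputed sequence of character or phoneme ids
    model.eval()
    text = torch.LongTensor(text_ids)[None, :]  # [1, T_in]
    cond_input = {
        "speaker_ids": None if speaker_id is None else torch.LongTensor([speaker_id]),
        "x_vectors": None,   # or a [1, C] external speaker embedding when embeddings_per_sample is set
        "style_mel": None,   # or a reference spectrogram when GST is enabled
    }
    outputs = model.inference(text, cond_input=cond_input)
    return outputs["postnet_outputs"]  # [1, T_out, C] postnet spectrogram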