From f121b0ff5d55c7ad373bd2fe343c1996a0a5c0cf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Wed, 26 May 2021 16:03:24 +0200
Subject: [PATCH] update `speedy_speech.py` model for trainer

---
 TTS/tts/models/speedy_speech.py | 139 +++++++++++++++++++++++++++-----
 1 file changed, 121 insertions(+), 18 deletions(-)

diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py
index bc6e912c..daf67b6c 100644
--- a/TTS/tts/models/speedy_speech.py
+++ b/TTS/tts/models/speedy_speech.py
@@ -3,6 +3,9 @@ from torch import nn

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
@@ -46,7 +49,12 @@ class SpeedySpeech(nn.Module):
         positional_encoding=True,
         length_scale=1,
         encoder_type="residual_conv_bn",
-        encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
+        encoder_params={
+            "kernel_size": 4,
+            "dilations": 4 * [1, 2, 4] + [1],
+            "num_conv_blocks": 2,
+            "num_res_blocks": 13
+        },
         decoder_type="residual_conv_bn",
         decoder_params={
             "kernel_size": 4,
@@ -60,13 +68,17 @@ class SpeedySpeech(nn.Module):
     ):
         super().__init__()
-        self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
+        self.length_scale = float(length_scale) if isinstance(
+            length_scale, int) else length_scale
         self.emb = nn.Embedding(num_chars, hidden_channels)
-        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
+        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
+                               encoder_params, c_in_channels)
         if positional_encoding:
             self.pos_encoder = PositionalEncoding(hidden_channels)
-        self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
-        self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
+        self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
+                               decoder_params)
+        self.duration_predictor = DurationPredictor(hidden_channels +
+                                                    c_in_channels)

         if num_speakers > 1 and not external_c:
             # speaker embedding layer
@@ -93,7 +105,9 @@ class SpeedySpeech(nn.Module):
         """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
-        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
+        o_en_ex = torch.matmul(
+            attn.squeeze(1).transpose(1, 2), en.transpose(1,
+                                                          2)).transpose(1, 2)
         return o_en_ex, attn

     def format_durations(self, o_dr_log, x_mask):
@@ -127,7 +141,8 @@ class SpeedySpeech(nn.Module):
         x_emb = torch.transpose(x_emb, 1, -1)

         # compute sequence masks
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
+                                 1).to(x.dtype)

         # encoder pass
         o_en = self.encoder(x_emb, x_mask)
@@ -140,7 +155,8 @@ class SpeedySpeech(nn.Module):
         return o_en, o_en_dp, x_mask, g

     def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
+                                 1).to(o_en_dp.dtype)
         # expand o_en with durations
         o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
         # positional encoding
@@ -153,8 +169,17 @@ class SpeedySpeech(nn.Module):
         o_de = self.decoder(o_en_ex, y_mask, g=g)
         return o_de, attn.transpose(1, 2)

-    def forward(self, x, x_lengths, y_lengths, dr, g=None):  # pylint: disable=unused-argument
+    def forward(self,
+                x,
+                x_lengths,
+                y_lengths,
+                dr,
+                cond_input={
+                    'x_vectors': None,
+                    'speaker_ids': None
+                }):  # pylint: disable=unused-argument
         """
+        TODO: speaker embedding for speaker_ids
         Shapes:
             x: [B, T_max]
             x_lengths: [B]
@@ -162,35 +187,113 @@ class SpeedySpeech(nn.Module):
             dr: [B, T_max]
             g: [B, C]
         """
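+        # use the external speaker embedding (x-vector) as conditioning when
+        # provided; 'speaker_ids' is accepted but not consumed yet (see the
+        # TODO above)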
+        g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
-        o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
-        return o_de, o_dr_log.squeeze(1), attn
+        o_de, attn = self._forward_decoder(o_en,
+                                           o_en_dp,
+                                           dr,
+                                           x_mask,
+                                           y_lengths,
+                                           g=g)
+        outputs = {
+            'model_outputs': o_de.transpose(1, 2),
+            'durations_log': o_dr_log.squeeze(1),
+            'alignments': attn
+        }
+        return outputs

-    def inference(self, x, x_lengths, g=None):  # pylint: disable=unused-argument
+    def inference(self,
+                  x,
+                  cond_input={
+                      'x_vectors': None,
+                      'speaker_ids': None
+                  }):  # pylint: disable=unused-argument
         """
         Shapes:
             x: [B, T_max]
             x_lengths: [B]
             g: [B, C]
         """
+        g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
+        x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
         # input sequence should be greater than the max convolution size
         inference_padding = 5
         if x.shape[1] < 13:
             inference_padding += 13 - x.shape[1]
         # pad input to prevent dropping the last word
-        x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
+        x = torch.nn.functional.pad(x,
+                                    pad=(0, inference_padding),
+                                    mode="constant",
+                                    value=0)
         o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
         # duration predictor pass
         o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
         o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
         y_lengths = o_dr.sum(1)
-        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
-        return o_de, attn
+        o_de, attn = self._forward_decoder(o_en,
+                                           o_en_dp,
+                                           o_dr,
+                                           x_mask,
+                                           y_lengths,
+                                           g=g)
+        outputs = {
+            'model_outputs': o_de.transpose(1, 2),
+            'alignments': attn,
+            'durations_log': None
+        }
+        return outputs

-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def train_step(self, batch: dict, criterion: nn.Module):
+        text_input = batch['text_input']
+        text_lengths = batch['text_lengths']
+        mel_input = batch['mel_input']
+        mel_lengths = batch['mel_lengths']
+        x_vectors = batch['x_vectors']
+        speaker_ids = batch['speaker_ids']
+        durations = batch['durations']
+
+        cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
+        outputs = self.forward(text_input, text_lengths, mel_lengths,
+                               durations, cond_input)
+
+        # compute loss
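+        # durations are predicted in the log domain, so the duration target
+        # is log(1 + durations); the +1 keeps zero-length tokens finite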
+        loss_dict = criterion(outputs['model_outputs'], mel_input,
+                              mel_lengths, outputs['durations_log'],
+                              torch.log(1 + durations), text_lengths)
+
+        # compute alignment error (the lower the better)
+        align_error = 1 - alignment_diagonal_score(outputs['alignments'],
+                                                   binary=True)
+        loss_dict["align_error"] = align_error
+        return outputs, loss_dict
+
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        model_outputs = outputs['model_outputs']
+        alignments = outputs['alignments']
+        mel_input = batch['mel_input']
+
+        pred_spec = model_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        # Sample audio
+        train_audio = ap.inv_melspectrogram(pred_spec.T)
+        return figures, train_audio
+
+    def eval_step(self, batch: dict, criterion: nn.Module):
+        return self.train_step(batch, criterion)
+
+    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        return self.train_log(ap, batch, outputs)
+
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
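
Usage note (not part of the diff): a minimal sketch of the new dict-based
inference API introduced above, assuming the default encoder/decoder
parameters; the sizes (129 symbols, 80 mel bands, 128 hidden channels) are
made up for illustration. `x_lengths` is now derived inside `inference`, and
results come back as a dict instead of a tuple.

    import torch

    from TTS.tts.models.speedy_speech import SpeedySpeech

    # made-up model sizes; real values come from the training config
    model = SpeedySpeech(num_chars=129, out_channels=80, hidden_channels=128)
    model.eval()

    x = torch.randint(0, 129, (1, 37))  # one tokenized sentence, [B, T_max]
    with torch.no_grad():
        outputs = model.inference(x)  # default cond_input: no x-vectors

    mel = outputs['model_outputs']  # [B, T_de, C_mel] after the transpose
    attn = outputs['alignments']    # token-to-frame alignment

During training the same dict contract is used: the trainer is expected to
call `train_step(batch, criterion)` and receive `(outputs, loss_dict)`, so
logging (`train_log`) and evaluation (`eval_step`/`eval_log`) can reuse the
keys `model_outputs`, `durations_log`, and `alignments` uniformly.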