diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index ca059ab9..09e58ce7 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -38,7 +38,6 @@ class GlowTTS(nn.Module):
         encoder_params (dict): encoder module parameters.
         speaker_embedding_dim (int): channels of external speaker embedding vectors.
     """
-
     def __init__(
         self,
         num_chars,
@@ -133,27 +132,29 @@ class GlowTTS(nn.Module):
     @staticmethod
     def compute_outputs(attn, o_mean, o_log_scale, x_mask):
         # compute final values with the computed alignment
-        y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
-        y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
-            1, 2
-        )  # [b, t', t], [b, t, d] -> [b, d, t']
+        y_mean = torch.matmul(
+            attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
+                1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
+        y_log_scale = torch.matmul(
+            attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
+                1, 2)).transpose(1, 2)  # [b, t', t], [b, t, d] -> [b, d, t']
         # compute total duration with adjustment
         o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
         return y_mean, y_log_scale, o_attn_dur

-    def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+    def forward(self, x, x_lengths, y, y_lengths=None, cond_input={'x_vectors': None}):
         """
         Shapes:
             x: [B, T]
             x_lengths: B
-            y: [B, C, T]
+            y: [B, T, C]
             y_lengths: B
-            g: [B, C] or B
+            cond_input['x_vectors']: [B, C]
         """
+        y = y.transpose(1, 2)  # [B, T, C] -> [B, C, T]
         y_max_length = y.size(2)
         # norm speaker embeddings
+        g = cond_input['x_vectors']
         if g is not None:
             if self.speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
@@ -161,29 +162,54 @@ class GlowTTS(nn.Module):
                 g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # drop residual frames wrt num_squeeze and set y_lengths.
-        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        y, y_lengths, y_max_length, attn = self.preprocess(
+            y, y_lengths, y_max_length, None)
         # create masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
         # find the alignment path
         with torch.no_grad():
             o_scale = torch.exp(-2 * o_log_scale)
-            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
-            logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+            logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
+            logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
+                                 (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
+                                 z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+            logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
+                              [1]).unsqueeze(-1)  # [b, t, 1]
             logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
-            attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+            attn = maximum_path(logp,
+                                attn_mask.squeeze(1)).unsqueeze(1).detach()
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
+            attn, o_mean, o_log_scale, x_mask)
         attn = attn.squeeze(1).permute(0, 2, 1)
-        return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
+        outputs = {
+            'model_outputs': z,
+            'logdet': logdet,
+            'y_mean': y_mean,
+            'y_log_scale': y_log_scale,
+            'alignments': attn,
+            'durations_log': o_dur_log,
+            'total_durations_log': o_attn_dur
+        }
+        return outputs

     @torch.no_grad()
-    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+    def inference_with_MAS(self,
+                           x,
+                           x_lengths,
+                           y=None,
+                           y_lengths=None,
+                           attn=None,
+                           g=None):
         """
         It's similar to the teacher forcing in Tacotron.
         It was proposed in: https://arxiv.org/abs/2104.05557
@@ -203,24 +229,33 @@ class GlowTTS(nn.Module):
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # drop residual frames wrt num_squeeze and set y_lengths.
-        y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
+        y, y_lengths, y_max_length, attn = self.preprocess(
+            y, y_lengths, y_max_length, None)
         # create masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
         # find the alignment path between z and encoder output
         o_scale = torch.exp(-2 * o_log_scale)
-        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)  # [b, t, 1]
-        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))  # [b, t, d] x [b, d, t'] = [b, t, t']
-        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)  # [b, t, d] x [b, d, t'] = [b, t, t']
-        logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)  # [b, t, 1]
+        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
+                          [1]).unsqueeze(-1)  # [b, t, 1]
+        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
+                             (z**2))  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
+                             z)  # [b, t, d] x [b, d, t'] = [b, t, t']
+        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
+                          [1]).unsqueeze(-1)  # [b, t, 1]
         logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']
         attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
-        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
+            attn, o_mean, o_log_scale, x_mask)
         attn = attn.squeeze(1).permute(0, 2, 1)

         # get predicted aligned distribution
@@ -228,8 +263,16 @@ class GlowTTS(nn.Module):
         # reverse the decoder and predict using the aligned distribution
         y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
-
-        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
+        outputs = {
+            'model_outputs': y,
+            'logdet': logdet,
+            'y_mean': y_mean,
+            'y_log_scale': y_log_scale,
+            'alignments': attn,
+            'durations_log': o_dur_log,
+            'total_durations_log': o_attn_dur
+        }
+        return outputs

     @torch.no_grad()
     def decoder_inference(self, y, y_lengths=None, g=None):
@@ -247,7 +290,8 @@ class GlowTTS(nn.Module):
         else:
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h, 1]

-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(y.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(y.dtype)

         # decoder pass
         z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
@@ -266,28 +310,98 @@ class GlowTTS(nn.Module):
             g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]

         # embedding pass
-        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
+        o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
+                                                              x_lengths,
+                                                              g=g)
         # compute output durations
         w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
         w_ceil = torch.ceil(w)
         y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
         y_max_length = None
         # compute masks
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
+                                 1).to(x_mask.dtype)
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         # compute attention mask
-        attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
-        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
+        attn = generate_path(w_ceil.squeeze(1),
+                             attn_mask.squeeze(1)).unsqueeze(1)
+        y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
+            attn, o_mean, o_log_scale, x_mask)

-        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.inference_noise_scale) * y_mask
+        z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
+             self.inference_noise_scale) * y_mask
         # decoder pass
         y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
         attn = attn.squeeze(1).permute(0, 2, 1)
-        return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
+        outputs = {
+            'model_outputs': y,
+            'logdet': logdet,
+            'y_mean': y_mean,
+            'y_log_scale': y_log_scale,
+            'alignments': attn,
+            'durations_log': o_dur_log,
+            'total_durations_log': o_attn_dur
+        }
+        return outputs
+
+    def train_step(self, batch: dict, criterion: nn.Module):
+        """Perform a single training step by fetching the right set of samples from the batch.
+
+        Args:
+            batch (dict): a dictionary of training batch tensors.
+            criterion (nn.Module): the loss module computing the model losses.
+        """
+        text_input = batch['text_input']
+        text_lengths = batch['text_lengths']
+        mel_input = batch['mel_input']
+        mel_lengths = batch['mel_lengths']
+        x_vectors = batch['x_vectors']
+
+        outputs = self.forward(text_input,
+                               text_lengths,
+                               mel_input,
+                               mel_lengths,
+                               cond_input={"x_vectors": x_vectors})
+
+        loss_dict = criterion(outputs['model_outputs'], outputs['y_mean'],
+                              outputs['y_log_scale'], outputs['logdet'],
+                              mel_lengths, outputs['durations_log'],
+                              outputs['total_durations_log'], text_lengths)
+
+        # compute alignment error (the lower, the better)
+        align_error = 1 - alignment_diagonal_score(outputs['alignments'], binary=True)
+        loss_dict["align_error"] = align_error
+        return outputs, loss_dict
+
+    def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        model_outputs = outputs['model_outputs']
+        alignments = outputs['alignments']
+        mel_input = batch['mel_input']
+
+        pred_spec = model_outputs[0].data.cpu().numpy()
+        gt_spec = mel_input[0].data.cpu().numpy()
+        align_img = alignments[0].data.cpu().numpy()
+
+        figures = {
+            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
+            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
+            "alignment": plot_alignment(align_img, output_fig=False),
+        }
+
+        # Sample audio
+        train_audio = ap.inv_melspectrogram(pred_spec.T)
+        return figures, train_audio
+
+    def eval_step(self, batch: dict, criterion: nn.Module):
+        return self.train_step(batch, criterion)
+
+    def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
+        return self.train_log(ap, batch, outputs)

     def preprocess(self, y, y_lengths, y_max_length, attn=None):
         if y_max_length is not None:
-            y_max_length = (y_max_length // self.num_squeeze) * self.num_squeeze
+            y_max_length = (y_max_length //
+                            self.num_squeeze) * self.num_squeeze
             y = y[:, :, :y_max_length]
         if attn is not None:
             attn = attn[:, :, :, :y_max_length]
@@ -297,9 +411,7 @@ class GlowTTS(nn.Module):
     def store_inverse(self):
         self.decoder.store_inverse()

-    def load_checkpoint(
-        self, config, checkpoint_path, eval=False
-    ):  # pylint: disable=unused-argument, redefined-builtin
+    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
         state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
         self.load_state_dict(state["model"])
         if eval:
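
The logp1..logp4 terms in forward() and inference_with_MAS() factorise the diagonal Gaussian log-likelihood log N(z; o_mean, exp(o_log_scale)) over every (token, frame) pair, so the [b, t, t'] score matrix fed to maximum_path comes out of two batched matmuls instead of a materialised [b, d, t, t'] tensor. A minimal standalone check of that identity (tensor names mirror the diff; shapes are illustrative, not taken from the PR):

import math
import torch

b, d, t, t_prime = 2, 8, 5, 11            # batch, channels, tokens, frames
o_mean = torch.randn(b, d, t)             # encoder means,      [b, d, t]
o_log_scale = 0.1 * torch.randn(b, d, t)  # encoder log-sigmas, [b, d, t]
z = torch.randn(b, d, t_prime)            # decoder latents,    [b, d, t']

# factorised form, as in the diff
o_scale = torch.exp(-2 * o_log_scale)     # 1 / sigma^2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4      # [b, t, t']

# naive form: sum_d log N(z[:, d, t'] ; o_mean[:, d, t], sigma[:, d, t])
diff = z.unsqueeze(2) - o_mean.unsqueeze(-1)          # [b, d, t, t']
naive = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale.unsqueeze(-1)
                  - 0.5 * diff ** 2 * o_scale.unsqueeze(-1), dim=1)
assert torch.allclose(logp, naive, atol=1e-4)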
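At inference time the predicted durations w_ceil replace MAS: generate_path expands them into a hard monotonic alignment. The sketch below reproduces that expansion with cumulative sums for a single sentence; it illustrates the expected behaviour and is not the library's generate_path implementation:

import torch

durations = torch.tensor([2, 1, 3])   # ceil'd per-token durations, as in w_ceil
num_frames = int(durations.sum())     # this sentence's entry in y_lengths
ends = torch.cumsum(durations, 0)     # frame index where each token stops
starts = ends - durations             # frame index where each token begins
frames = torch.arange(num_frames)

# attn[i, j] == 1 iff frame j is assigned to token i  -> shape [t, t']
attn = ((frames[None, :] >= starts[:, None]) &
        (frames[None, :] < ends[:, None])).int()
print(attn)
# tensor([[1, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 1, 1]])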
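Finally, preprocess() floors the mel frame axis to a multiple of num_squeeze so the decoder's squeeze operation divides it evenly; the dropped remainder frames are the "residual frames" mentioned in the comments. A minimal sketch of the trimming, with an illustrative num_squeeze of 2:

import torch

num_squeeze = 2                    # illustrative; set by the decoder config
y = torch.randn(1, 80, 101)        # [B, C, T'] with an odd frame count
y_max_length = (y.size(2) // num_squeeze) * num_squeeze   # 101 -> 100
y = y[:, :, :y_max_length]
print(y.shape)                     # torch.Size([1, 80, 100])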