diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py
index 9f20f6bb..2c944008 100755
--- a/TTS/tts/models/glow_tts.py
+++ b/TTS/tts/models/glow_tts.py
@@ -154,10 +154,10 @@ class GlowTTS(nn.Module):
             y_lengths: B
             g: [B, C] or B
         """
-        y_max_length = y.size(2)
         y = y.transpose(1, 2)
+        y_max_length = y.size(2)
         # norm speaker embeddings
-        g = cond_input["x_vectors"]
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
@@ -196,19 +196,23 @@ class GlowTTS(nn.Module):
         return outputs
 
     @torch.no_grad()
-    def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
+    def inference_with_MAS(
+        self, x, x_lengths, y=None, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         It's similar to the teacher forcing in Tacotron.
         It was proposed in: https://arxiv.org/abs/2104.05557
         Shapes:
             x: [B, T]
             x_lenghts: B
-            y: [B, C, T]
+            y: [B, T, C]
             y_lengths: B
             g: [B, C] or B
         """
+        y = y.transpose(1, 2)
         y_max_length = y.size(2)
         # norm speaker embeddings
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.external_speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
@@ -253,14 +257,18 @@ class GlowTTS(nn.Module):
         return outputs
 
     @torch.no_grad()
-    def decoder_inference(self, y, y_lengths=None, g=None):
+    def decoder_inference(
+        self, y, y_lengths=None, cond_input={"x_vectors": None}
+    ):  # pylint: disable=dangerous-default-value
         """
         Shapes:
-            y: [B, C, T]
+            y: [B, T, C]
             y_lengths: B
             g: [B, C] or B
         """
+        y = y.transpose(1, 2)
         y_max_length = y.size(2)
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         # norm speaker embeddings
         if g is not None:
             if self.external_speaker_embedding_dim:
@@ -276,10 +284,14 @@ class GlowTTS(nn.Module):
 
         # reverse decoder and predict
         y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
-        return y, logdet
+        outputs = {}
+        outputs["model_outputs"] = y
+        outputs["logdet"] = logdet
+        return outputs
 
     @torch.no_grad()
-    def inference(self, x, x_lengths, g=None):
+    def inference(self, x, x_lengths, cond_input={"x_vectors": None}):  # pylint: disable=dangerous-default-value
+        g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
         if g is not None:
             if self.speaker_embedding_dim:
                 g = F.normalize(g).unsqueeze(-1)
diff --git a/tests/tts_tests/test_glow_tts.py b/tests/tts_tests/test_glow_tts.py
index 486de274..8a2a8fb3 100644
--- a/tests/tts_tests/test_glow_tts.py
+++ b/tests/tts_tests/test_glow_tts.py
@@ -34,7 +34,7 @@ class GlowTTSTrainTest(unittest.TestCase):
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
 
@@ -114,10 +114,17 @@ class GlowTTSTrainTest(unittest.TestCase):
         optimizer = optim.Adam(model.parameters(), lr=0.001)
         for _ in range(5):
             optimizer.zero_grad()
-            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
-                input_dummy, input_lengths, mel_spec, mel_lengths, None
+            outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
+            loss_dict = criterion(
+                outputs["model_outputs"],
+                outputs["y_mean"],
+                outputs["y_log_scale"],
+                outputs["logdet"],
+                mel_lengths,
+                outputs["durations_log"],
+                outputs["total_durations_log"],
+                input_lengths,
             )
-            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths)
             loss = loss_dict["loss"]
             loss.backward()
             optimizer.step()
@@ -137,7 +144,7 @@ class GlowTTSInferenceTest(unittest.TestCase):
         input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
         input_lengths = torch.randint(100, 129, (8,)).long().to(device)
         input_lengths[-1] = 128
-        mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
+        mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
         mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
         speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
 
@@ -175,12 +182,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
         print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
 
         # inference encoder and decoder with MAS
-        y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
+        y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
 
-        y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
+        y2 = model.decoder_inference(mel_spec, mel_lengths)
 
         assert (
-            y_dec.shape == y.shape
+            y2["model_outputs"].shape == y["model_outputs"].shape
         ), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
-            y.shape, y_dec.shape
+            y["model_outputs"].shape, y2["model_outputs"].shape
        )
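Reviewer note: below is a minimal sketch of the call pattern after this change, assuming an already-constructed `GlowTTS` instance named `model` (see the test setup above for a full example); the tensor sizes (`num_mels=80`, 256-dim x-vectors) are illustrative assumptions, not values taken from this diff. Mels now travel as `[B, T, C]`, conditioning is passed in a `cond_input` dict, and every inference path returns a dict keyed by `model_outputs` instead of a tuple.

```python
import torch

# Illustrative shapes only; num_mels=80 and the 256-dim x-vector are assumptions.
batch, t_len, num_mels = 8, 30, 80
mel_spec = torch.rand(batch, t_len, num_mels)            # [B, T, C] after this diff
mel_lengths = torch.randint(20, t_len, (batch,)).long()  # per-sample frame counts
cond_input = {"x_vectors": torch.rand(batch, 256)}       # or {"x_vectors": None}

# `model` is assumed to be a configured GlowTTS instance.
outputs = model.decoder_inference(mel_spec, mel_lengths, cond_input=cond_input)
mel_out = outputs["model_outputs"]  # decoded mels, same [B, T, C] layout
logdet = outputs["logdet"]          # flow log-determinant
```

The same dict convention holds for `forward`, `inference`, and `inference_with_MAS`, which is what the updated shape assertion exercises by comparing `y["model_outputs"].shape` against `y2["model_outputs"].shape`.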