mirror of https://github.com/coqui-ai/TTS.git
fix glow-tts inference and forward functions for handling `cond_input`
and refactor its testpull/602/head
parent
d4b1acfa81
commit
223502d827
|
@ -154,10 +154,10 @@ class GlowTTS(nn.Module):
|
|||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
"""
|
||||
y_max_length = y.size(2)
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = cond_input["x_vectors"]
|
||||
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||
if g is not None:
|
||||
if self.speaker_embedding_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -196,19 +196,23 @@ class GlowTTS(nn.Module):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||
def inference_with_MAS(
|
||||
self, x, x_lengths, y=None, y_lengths=None, cond_input={"x_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
It's similar to the teacher forcing in Tacotron.
|
||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||
Shapes:
|
||||
x: [B, T]
|
||||
x_lenghts: B
|
||||
y: [B, C, T]
|
||||
y: [B, T, C]
|
||||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||
if g is not None:
|
||||
if self.external_speaker_embedding_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
@ -253,14 +257,18 @@ class GlowTTS(nn.Module):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def decoder_inference(self, y, y_lengths=None, g=None):
|
||||
def decoder_inference(
|
||||
self, y, y_lengths=None, cond_input={"x_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
y: [B, C, T]
|
||||
y: [B, T, C]
|
||||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||
# norm speaker embeddings
|
||||
if g is not None:
|
||||
if self.external_speaker_embedding_dim:
|
||||
|
@ -276,10 +284,14 @@ class GlowTTS(nn.Module):
|
|||
# reverse decoder and predict
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
|
||||
return y, logdet
|
||||
outputs = {}
|
||||
outputs["model_outputs"] = y
|
||||
outputs["logdet"] = logdet
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, x_lengths, g=None):
|
||||
def inference(self, x, x_lengths, cond_input={"x_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
g = cond_input["x_vectors"] if cond_input is not None and "x_vectors" in cond_input else None
|
||||
if g is not None:
|
||||
if self.speaker_embedding_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
|
|
|
@ -34,7 +34,7 @@ class GlowTTSTrainTest(unittest.TestCase):
|
|||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
|
||||
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||
|
||||
|
@ -114,10 +114,17 @@ class GlowTTSTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
||||
for _ in range(5):
|
||||
optimizer.zero_grad()
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, None
|
||||
outputs = model.forward(input_dummy, input_lengths, mel_spec, mel_lengths, None)
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
outputs["y_mean"],
|
||||
outputs["y_log_scale"],
|
||||
outputs["logdet"],
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
outputs["total_durations_log"],
|
||||
input_lengths,
|
||||
)
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, input_lengths)
|
||||
loss = loss_dict["loss"]
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
@ -137,7 +144,7 @@ class GlowTTSInferenceTest(unittest.TestCase):
|
|||
input_dummy = torch.randint(0, 24, (8, 128)).long().to(device)
|
||||
input_lengths = torch.randint(100, 129, (8,)).long().to(device)
|
||||
input_lengths[-1] = 128
|
||||
mel_spec = torch.rand(8, c.audio["num_mels"], 30).to(device)
|
||||
mel_spec = torch.rand(8, 30, c.audio["num_mels"]).to(device)
|
||||
mel_lengths = torch.randint(20, 30, (8,)).long().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8,)).long().to(device)
|
||||
|
||||
|
@ -175,12 +182,12 @@ class GlowTTSInferenceTest(unittest.TestCase):
|
|||
print(" > Num parameters for GlowTTS model:%s" % (count_parameters(model)))
|
||||
|
||||
# inference encoder and decoder with MAS
|
||||
y, *_ = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths, None)
|
||||
y = model.inference_with_MAS(input_dummy, input_lengths, mel_spec, mel_lengths)
|
||||
|
||||
y_dec, _ = model.decoder_inference(mel_spec, mel_lengths)
|
||||
y2 = model.decoder_inference(mel_spec, mel_lengths)
|
||||
|
||||
assert (
|
||||
y_dec.shape == y.shape
|
||||
y2["model_outputs"].shape == y["model_outputs"].shape
|
||||
), "Difference between the shapes of the glowTTS inference with MAS ({}) and the inference using only the decoder ({}) !!".format(
|
||||
y.shape, y_dec.shape
|
||||
y["model_outputs"].shape, y2["model_outputs"].shape
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue