update tacotron model to return `model_outputs`

pull/602/head
Eren Gölge 2021-05-26 09:54:48 +02:00
parent f09ec7e3a7
commit c9790bee2c
2 changed files with 7 additions and 7 deletions

View File

@ -255,7 +255,7 @@ class Tacotron(TacotronAbstract):
outputs['alignments_backward'] = alignments_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
outputs.update({
'postnet_outputs': postnet_outputs,
'model_outputs': postnet_outputs,
'decoder_outputs': decoder_outputs,
'alignments': alignments,
'stop_tokens': stop_tokens
@ -287,7 +287,7 @@ class Tacotron(TacotronAbstract):
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)
outputs = {
'postnet_outputs': postnet_outputs,
'model_outputs': postnet_outputs,
'decoder_outputs': decoder_outputs,
'alignments': alignments,
'stop_tokens': stop_tokens
@ -335,7 +335,7 @@ class Tacotron(TacotronAbstract):
# compute loss
loss_dict = criterion(
outputs['postnet_outputs'],
outputs['model_outputs'],
outputs['decoder_outputs'],
mel_input,
linear_input,
@ -355,7 +355,7 @@ class Tacotron(TacotronAbstract):
return outputs, loss_dict
def train_log(self, ap, batch, outputs):
postnet_outputs = outputs['postnet_outputs']
postnet_outputs = outputs['model_outputs']
alignments = outputs['alignments']
alignments_backward = outputs['alignments_backward']
mel_input = batch['mel_input']

View File

@ -233,7 +233,7 @@ class Tacotron2(TacotronAbstract):
outputs['alignments_backward'] = alignments_backward
outputs['decoder_outputs_backward'] = decoder_outputs_backward
outputs.update({
'postnet_outputs': postnet_outputs,
'model_outputs': postnet_outputs,
'decoder_outputs': decoder_outputs,
'alignments': alignments,
'stop_tokens': stop_tokens
@ -254,7 +254,7 @@ class Tacotron2(TacotronAbstract):
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
else:
x_vector = cond_input
x_vector = cond_input['x_vectors']
encoder_outputs = self._concat_speaker_embedding(
encoder_outputs, x_vector)
@ -266,7 +266,7 @@ class Tacotron2(TacotronAbstract):
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
outputs = {
'postnet_outputs': postnet_outputs,
'model_outputs': postnet_outputs,
'decoder_outputs': decoder_outputs,
'alignments': alignments,
'stop_tokens': stop_tokens