tacotron parse output bug fix

pull/10/head
Eren Golge 2019-03-06 13:10:54 +01:00
parent 4326582bb1
commit a4474abd83
1 changed files with 3 additions and 3 deletions

View File

@ -376,12 +376,12 @@ class Decoder(nn.Module):
self.attention = inputs.data.new(B, T).zero_()
self.attention_cum = inputs.data.new(B, T).zero_()
def _parse_outputs(self, outputs, stop_tokens, attentions):
def _parse_outputs(self, outputs, attentions, stop_tokens):
# Back to batch first
attentions = torch.stack(attentions).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
return outputs, stop_tokens, attentions
stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
return outputs, attentions, stop_tokens
def decode(self,
inputs,