pull/10/head
Eren Golge 2018-04-30 06:18:03 -07:00
parent 754e0d3b63
commit d2657cbf3a
2 changed files with 2 additions and 2 deletions

View File

@ -128,7 +128,7 @@ class LJSpeechDataset(Dataset):
linear = torch.FloatTensor(linear)
mel = torch.FloatTensor(mel)
mel_lengths = torch.LongTensor(mel_lengths)
stop_targets = torch.FloatTensor(stop_targets).squeeze()
stop_targets = torch.FloatTensor(stop_targets)
return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]

View File

@ -33,7 +33,7 @@ class Tacotron(nn.Module):
# Reshape
# batch x time x dim
mel_outputs = mel_outputs.view(B, -1, self.mel_dim)
stop_tokens = self.stopnet(mel_outputs)
stop_tokens = self.stopnet(mel_outputs).squeeze()
linear_outputs = self.postnet(mel_outputs)
linear_outputs = self.last_linear(linear_outputs)
return mel_outputs, linear_outputs, alignments, stop_tokens