fix for 2 dim memory tensor

pull/10/head
Eren Golge 2019-09-18 02:51:56 +02:00
parent e085c4757d
commit 9a2bd7f9af
1 changed files with 4 additions and 1 deletions

View File

@ -183,6 +183,9 @@ class Decoder(nn.Module):
return outputs, stop_tokens, alignments
def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.mel_channels * (self.r - 1) :]
else:
return memory[:, :, self.mel_channels * (self.r - 1) :]
def decode(self, memory):