mirror of https://github.com/coqui-ai/TTS.git
fix for 2 dim memory tensor
parent
e085c4757d
commit
9a2bd7f9af
|
@ -183,7 +183,10 @@ class Decoder(nn.Module):
|
|||
return outputs, stop_tokens, alignments
|
||||
|
||||
def _update_memory(self, memory):
|
||||
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
||||
if len(memory.shape) == 2:
|
||||
return memory[:, self.mel_channels * (self.r - 1) :]
|
||||
else:
|
||||
return memory[:, :, self.mel_channels * (self.r - 1) :]
|
||||
|
||||
def decode(self, memory):
|
||||
query_input = torch.cat((memory, self.context), -1)
|
||||
|
|
Loading…
Reference in New Issue