From 9a2bd7f9af6456abbb452eba0260fb3b67312405 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 18 Sep 2019 02:51:56 +0200 Subject: [PATCH] fix for 2 dim memory tensor --- layers/tacotron2.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 6e914fd7..0ea8b18f 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -183,7 +183,10 @@ class Decoder(nn.Module): return outputs, stop_tokens, alignments def _update_memory(self, memory): - return memory[:, :, self.mel_channels * (self.r - 1) :] + if len(memory.shape) == 2: + return memory[:, self.mel_channels * (self.r - 1) :] + else: + return memory[:, :, self.mel_channels * (self.r - 1) :] def decode(self, memory): query_input = torch.cat((memory, self.context), -1)