diff --git a/layers/tacotron2.py b/layers/tacotron2.py index 6e914fd7..0ea8b18f 100644 --- a/layers/tacotron2.py +++ b/layers/tacotron2.py @@ -183,7 +183,10 @@ class Decoder(nn.Module): return outputs, stop_tokens, alignments def _update_memory(self, memory): - return memory[:, :, self.mel_channels * (self.r - 1) :] + if len(memory.shape) == 2: + return memory[:, self.mel_channels * (self.r - 1) :] + else: + return memory[:, :, self.mel_channels * (self.r - 1) :] def decode(self, memory): query_input = torch.cat((memory, self.context), -1)