From 14d33662ea7ec00da942bef53d571a20dac63804 Mon Sep 17 00:00:00 2001 From: erogol Date: Wed, 6 Jan 2021 13:15:56 +0100 Subject: [PATCH] input shapes for tacotron models --- TTS/tts/models/tacotron.py | 10 ++++++---- TTS/tts/models/tacotron2.py | 9 +++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 73434aa5..61eea893 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -134,10 +134,12 @@ class Tacotron(TacotronAbstract): def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): """ Shapes: - - characters: B x T_in - - text_lengths: B - - mel_specs: B x T_out x D - - speaker_ids: B x 1 + characters: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + speaker_ids: [B, 1] + speaker_embeddings: [B, C] """ input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths) # B x T_in x embed_dim diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 317bdbc8..e56e4ca0 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -130,6 +130,15 @@ class Tacotron2(TacotronAbstract): return mel_outputs, mel_outputs_postnet, alignments def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): + """ + Shapes: + text: [B, T_in] + text_lengths: [B] + mel_specs: [B, T_out, C] + mel_lengths: [B] + speaker_ids: [B, 1] + speaker_embeddings: [B, C] + """ # compute mask for padding # B x T_in_max (boolean) input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)