input shapes for tacotron models

pull/10/head
erogol 2021-01-06 13:15:56 +01:00
parent f288e9a260
commit 14d33662ea
2 changed files with 15 additions and 4 deletions

View File

@ -134,10 +134,12 @@ class Tacotron(TacotronAbstract):
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
"""
Shapes:
- characters: B x T_in
- text_lengths: B
- mel_specs: B x T_out x D
- speaker_ids: B x 1
characters: [B, T_in]
text_lengths: [B]
mel_specs: [B, T_out, C]
mel_lengths: [B]
speaker_ids: [B, 1]
speaker_embeddings: [B, C]
"""
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
# B x T_in x embed_dim

View File

@ -130,6 +130,15 @@ class Tacotron2(TacotronAbstract):
return mel_outputs, mel_outputs_postnet, alignments
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
"""
Shapes:
text: [B, T_in]
text_lengths: [B]
mel_specs: [B, T_out, C]
mel_lengths: [B]
speaker_ids: [B, 1]
speaker_embeddings: [B, C]
"""
# compute mask for padding
# B x T_in_max (boolean)
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)