mirror of https://github.com/coqui-ai/TTS.git
input shapes for tacotron models
parent
f288e9a260
commit
14d33662ea
|
@ -134,10 +134,12 @@ class Tacotron(TacotronAbstract):
|
|||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||
"""
|
||||
Shapes:
|
||||
- characters: B x T_in
|
||||
- text_lengths: B
|
||||
- mel_specs: B x T_out x D
|
||||
- speaker_ids: B x 1
|
||||
characters: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
speaker_ids: [B, 1]
|
||||
speaker_embeddings: [B, C]
|
||||
"""
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x T_in x embed_dim
|
||||
|
|
|
@ -130,6 +130,15 @@ class Tacotron2(TacotronAbstract):
|
|||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
speaker_ids: [B, 1]
|
||||
speaker_embeddings: [B, C]
|
||||
"""
|
||||
# compute mask for padding
|
||||
# B x T_in_max (boolean)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
|
|
Loading…
Reference in New Issue