mirror of https://github.com/coqui-ai/TTS.git
input shapes for tacotron models
parent
f288e9a260
commit
14d33662ea
|
@ -134,10 +134,12 @@ class Tacotron(TacotronAbstract):
|
||||||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
- characters: B x T_in
|
characters: [B, T_in]
|
||||||
- text_lengths: B
|
text_lengths: [B]
|
||||||
- mel_specs: B x T_out x D
|
mel_specs: [B, T_out, C]
|
||||||
- speaker_ids: B x 1
|
mel_lengths: [B]
|
||||||
|
speaker_ids: [B, 1]
|
||||||
|
speaker_embeddings: [B, C]
|
||||||
"""
|
"""
|
||||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||||
# B x T_in x embed_dim
|
# B x T_in x embed_dim
|
||||||
|
|
|
@ -130,6 +130,15 @@ class Tacotron2(TacotronAbstract):
|
||||||
return mel_outputs, mel_outputs_postnet, alignments
|
return mel_outputs, mel_outputs_postnet, alignments
|
||||||
|
|
||||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||||
|
"""
|
||||||
|
Shapes:
|
||||||
|
text: [B, T_in]
|
||||||
|
text_lengths: [B]
|
||||||
|
mel_specs: [B, T_out, C]
|
||||||
|
mel_lengths: [B]
|
||||||
|
speaker_ids: [B, 1]
|
||||||
|
speaker_embeddings: [B, C]
|
||||||
|
"""
|
||||||
# compute mask for padding
|
# compute mask for padding
|
||||||
# B x T_in_max (boolean)
|
# B x T_in_max (boolean)
|
||||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||||
|
|
Loading…
Reference in New Issue