mirror of https://github.com/coqui-ai/TTS.git
Fix glow_tts imports
parent
570d5971be
commit
a89eb12aca
|
@ -7,9 +7,8 @@ from torch.nn import functional as F
|
|||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.layers.glow_tts.decoder import Decoder
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.helpers import sequence_mask
|
||||
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
|
@ -133,7 +132,7 @@ class GlowTTS(BaseTTS):
|
|||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -185,7 +184,7 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
It's similar to the teacher forcing in Tacotron.
|
||||
|
@ -246,7 +245,7 @@ class GlowTTS(BaseTTS):
|
|||
|
||||
@torch.no_grad()
|
||||
def decoder_inference(
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
|
@ -278,7 +277,9 @@ class GlowTTS(BaseTTS):
|
|||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
|
||||
def inference(
|
||||
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
|
||||
|
@ -331,7 +332,13 @@ class GlowTTS(BaseTTS):
|
|||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
|
||||
)
|
||||
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
|
|
Loading…
Reference in New Issue