Fix glow_tts imports

pull/800/head
Eren Gölge 2021-09-10 08:29:51 +00:00
parent 570d5971be
commit a89eb12aca
1 changed files with 14 additions and 7 deletions

View File

@ -7,9 +7,8 @@ from torch.nn import functional as F
from TTS.tts.configs import GlowTTSConfig
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.utils.helpers import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.helpers import generate_path, maximum_path, sequence_mask
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
@ -133,7 +132,7 @@ class GlowTTS(BaseTTS):
return y_mean, y_log_scale, o_attn_dur
def forward(
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
@ -185,7 +184,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def inference_with_MAS(
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
It's similar to the teacher forcing in Tacotron.
@ -246,7 +245,7 @@ class GlowTTS(BaseTTS):
@torch.no_grad()
def decoder_inference(
self, y, y_lengths=None, aux_input={"d_vectors": None, 'speaker_ids':None}
self, y, y_lengths=None, aux_input={"d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
"""
Shapes:
@ -278,7 +277,9 @@ class GlowTTS(BaseTTS):
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids":None}): # pylint: disable=dangerous-default-value
def inference(
self, x, aux_input={"x_lengths": None, "d_vectors": None, "speaker_ids": None}
): # pylint: disable=dangerous-default-value
x_lengths = aux_input["x_lengths"]
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
@ -331,7 +332,13 @@ class GlowTTS(BaseTTS):
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors, "speaker_ids":speaker_ids})
outputs = self.forward(
text_input,
text_lengths,
mel_input,
mel_lengths,
aux_input={"d_vectors": d_vectors, "speaker_ids": speaker_ids},
)
loss_dict = criterion(
outputs["model_outputs"],