mirror of https://github.com/coqui-ai/TTS.git
Comment Tacotron2 model
parent
92b6d98443
commit
3da79a4de4
|
@ -24,7 +24,7 @@ class Tacotron(BaseTacotron):
|
|||
a multi-speaker model. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
@ -38,7 +39,7 @@ class Tacotron2(BaseTacotron):
|
|||
Speaker manager for multi-speaker training. Uuse only for multi-speaker training. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
|
||||
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
|
||||
super().__init__(config)
|
||||
|
||||
self.speaker_manager = speaker_manager
|
||||
|
@ -132,11 +133,11 @@ class Tacotron2(BaseTacotron):
|
|||
"""Forward pass for training with Teacher Forcing.
|
||||
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
|
||||
text: :math:`[B, T_in]`
|
||||
text_lengths: :math:`[B]`
|
||||
mel_specs: :math:`[B, T_out, C]`
|
||||
mel_lengths: :math:`[B]`
|
||||
aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]`
|
||||
"""
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
|
@ -199,9 +200,9 @@ class Tacotron2(BaseTacotron):
|
|||
def inference(self, text, aux_input=None):
|
||||
"""Forward pass for inference with no Teacher-Forcing.
|
||||
|
||||
Shapes:
|
||||
text: :math:`[B, T_in]`
|
||||
text_lengths: :math:`[B]`
|
||||
Shapes:
|
||||
text: :math:`[B, T_in]`
|
||||
text_lengths: :math:`[B]`
|
||||
"""
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
|
@ -236,12 +237,12 @@ class Tacotron2(BaseTacotron):
|
|||
}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch, criterion):
|
||||
def train_step(self, batch:Dict, criterion:torch.nn.Module):
|
||||
"""A single training step. Forward pass and loss computation.
|
||||
|
||||
Args:
|
||||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
batch ([Dict]): A dictionary of input tensors.
|
||||
criterion ([type]): Callable criterion to compute model loss.
|
||||
"""
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
|
@ -296,6 +297,7 @@ class Tacotron2(BaseTacotron):
|
|||
return outputs, loss_dict
|
||||
|
||||
def _create_logs(self, batch, outputs, ap):
|
||||
"""Create dashboard log information."""
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
alignments_backward = outputs["alignments_backward"]
|
||||
|
@ -321,6 +323,7 @@ class Tacotron2(BaseTacotron):
|
|||
def train_log(
|
||||
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
|
||||
) -> None: # pylint: disable=no-self-use
|
||||
"""Log training progress."""
|
||||
ap = assets["audio_processor"]
|
||||
figures, audios = self._create_logs(batch, outputs, ap)
|
||||
logger.train_figures(steps, figures)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import math
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
|
|
|
@ -23,8 +23,10 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
|
|||
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
# TODO: check if you need AMP disabled
|
||||
# with torch.cuda.amp.autocast(enabled=False):
|
||||
mu1_sq = mu1.float().pow(2)
|
||||
mu2_sq = mu2.float().pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
|
||||
|
|
Loading…
Reference in New Issue