Comment Tacotron2 model

pull/887/head
Eren Gölge 2021-10-20 18:14:04 +00:00
parent 92b6d98443
commit 3da79a4de4
4 changed files with 20 additions and 16 deletions

View File

@ -24,7 +24,7 @@ class Tacotron(BaseTacotron):
a multi-speaker model. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager

View File

@ -1,5 +1,6 @@
# coding: utf-8
from typing import Dict
import torch
from coqpit import Coqpit
from torch import nn
@ -38,7 +39,7 @@ class Tacotron2(BaseTacotron):
Speaker manager for multi-speaker training. Use only for multi-speaker training. Defaults to None.
"""
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager=None):
def __init__(self, config: Coqpit, speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager
@ -132,11 +133,11 @@ class Tacotron2(BaseTacotron):
"""Forward pass for training with Teacher Forcing.
Shapes:
text: [B, T_in]
text_lengths: [B]
mel_specs: [B, T_out, C]
mel_lengths: [B]
aux_input: 'speaker_ids': [B, 1] and 'd_vectors':[B, C]
text: :math:`[B, T_in]`
text_lengths: :math:`[B]`
mel_specs: :math:`[B, T_out, C]`
mel_lengths: :math:`[B]`
aux_input: 'speaker_ids': :math:`[B, 1]` and 'd_vectors': :math:`[B, C]`
"""
aux_input = self._format_aux_input(aux_input)
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
@ -199,9 +200,9 @@ class Tacotron2(BaseTacotron):
def inference(self, text, aux_input=None):
"""Forward pass for inference with no Teacher-Forcing.
Shapes:
text: :math:`[B, T_in]`
text_lengths: :math:`[B]`
Shapes:
text: :math:`[B, T_in]`
text_lengths: :math:`[B]`
"""
aux_input = self._format_aux_input(aux_input)
embedded_inputs = self.embedding(text).transpose(1, 2)
@ -236,12 +237,12 @@ class Tacotron2(BaseTacotron):
}
return outputs
def train_step(self, batch, criterion):
def train_step(self, batch:Dict, criterion:torch.nn.Module):
"""A single training step. Forward pass and loss computation.
Args:
batch ([type]): [description]
criterion ([type]): [description]
batch (Dict): A dictionary of input tensors.
criterion (torch.nn.Module): Callable criterion to compute model loss.
"""
text_input = batch["text_input"]
text_lengths = batch["text_lengths"]
@ -296,6 +297,7 @@ class Tacotron2(BaseTacotron):
return outputs, loss_dict
def _create_logs(self, batch, outputs, ap):
"""Create dashboard log information."""
postnet_outputs = outputs["model_outputs"]
alignments = outputs["alignments"]
alignments_backward = outputs["alignments_backward"]
@ -321,6 +323,7 @@ class Tacotron2(BaseTacotron):
def train_log(
self, batch: dict, outputs: dict, logger: "Logger", assets: dict, steps: int
) -> None: # pylint: disable=no-self-use
"""Log training progress."""
ap = assets["audio_processor"]
figures, audios = self._create_logs(batch, outputs, ap)
logger.train_figures(steps, figures)

View File

@ -1,5 +1,4 @@
import math
import os
import random
from dataclasses import dataclass, field
from itertools import chain

View File

@ -23,8 +23,10 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
# TODO: check whether AMP (autocast) needs to be disabled for the computation below
# with torch.cuda.amp.autocast(enabled=False):
mu1_sq = mu1.float().pow(2)
mu2_sq = mu2.float().pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq