from math import sqrt

import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

from layers.tacotron2 import Encoder, Decoder, Postnet
from utils.generic_utils import sequence_mask


# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
    """Character embedding -> Encoder -> Decoder -> Postnet model.

    The postnet output is added to the decoder output as a residual. All
    three entry points (`forward`, `inference`, `inference_truncated`)
    return the same 4-tuple:
    ``(mel_outputs, mel_outputs_postnet, alignments, stop_tokens)``.
    """

    def __init__(self,
                 num_chars,
                 r,
                 attn_win=False,
                 attn_norm="softmax",
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 location_attn=True,
                 separate_stopnet=True):
        """Build the embedding, encoder, decoder and postnet.

        Args:
            num_chars: number of rows in the character embedding table.
            r: stored as ``n_frames_per_step`` and forwarded to the decoder.
            attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
            trans_agent, location_attn, separate_stopnet: forwarded
                unchanged, in this order, to ``Decoder``; their meaning is
                defined there.
        """
        super(Tacotron2, self).__init__()
        # Width shared by the embedding, the encoder and the decoder input.
        embedding_dim = 512
        self.n_mel_channels = 80
        self.n_frames_per_step = r

        self.embedding = nn.Embedding(num_chars, embedding_dim)
        # U(-val, val) has standard deviation val / sqrt(3), so this choice
        # of val gives the weights a standard deviation of exactly `std`.
        std = sqrt(2.0 / (num_chars + embedding_dim))
        val = sqrt(3.0) * std
        nn.init.uniform_(self.embedding.weight, -val, val)

        self.encoder = Encoder(embedding_dim)
        self.decoder = Decoder(embedding_dim, self.n_mel_channels, r,
                               attn_win, attn_norm, prenet_type,
                               prenet_dropout, forward_attn, trans_agent,
                               location_attn, separate_stopnet)
        self.postnet = Postnet(self.n_mel_channels)

    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
        """Swap dims 1 and 2 of both mel tensors; alignments pass through.

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments)``.
        """
        mel_outputs = mel_outputs.transpose(1, 2)
        mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
        return mel_outputs, mel_outputs_postnet, alignments

    def _postnet_and_shape(self, mel_outputs, alignments):
        """Apply the residual postnet, then reshape via `shape_outputs`.

        Shared tail of `forward`, `inference` and `inference_truncated`.

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments)`` as produced
            by `shape_outputs`.
        """
        # The postnet output is a correction added on top of the decoder
        # output, not a replacement for it.
        mel_outputs_postnet = mel_outputs + self.postnet(mel_outputs)
        return self.shape_outputs(mel_outputs, mel_outputs_postnet,
                                  alignments)

    def _embed(self, text):
        """Embed `text` and swap dims 1 and 2 for the encoder."""
        # Presumably text is (batch, time), so the result is
        # (batch, 512, time) -- TODO confirm against Encoder.
        return self.embedding(text).transpose(1, 2)

    def forward(self, text, text_lengths, mel_specs=None):
        """Run the full model, passing `mel_specs` through to the decoder.

        Args:
            text: character-id tensor fed to the embedding.
            text_lengths: per-item lengths, used for the padding mask and
                passed to the encoder.
            mel_specs: passed straight to the decoder (presumably the
                target spectrograms for teacher forcing -- confirm against
                ``Decoder``).

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments, stop_tokens)``.
        """
        # compute mask for padding
        mask = sequence_mask(text_lengths).to(text.device)
        encoder_outputs = self.encoder(self._embed(text), text_lengths)
        mel_outputs, stop_tokens, alignments = self.decoder(
            encoder_outputs, mel_specs, mask)
        mel_outputs, mel_outputs_postnet, alignments = \
            self._postnet_and_shape(mel_outputs, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

    def inference(self, text):
        """Synthesize from `text` using the encoder/decoder inference paths.

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments, stop_tokens)``.
        """
        encoder_outputs = self.encoder.inference(self._embed(text))
        mel_outputs, stop_tokens, alignments = self.decoder.inference(
            encoder_outputs)
        mel_outputs, mel_outputs_postnet, alignments = \
            self._postnet_and_shape(mel_outputs, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens

    def inference_truncated(self, text):
        """Preserve model states for continuous inference.

        Same pipeline as `inference`, but routed through the
        ``inference_truncated`` methods of the encoder and decoder.

        Returns:
            ``(mel_outputs, mel_outputs_postnet, alignments, stop_tokens)``.
        """
        encoder_outputs = self.encoder.inference_truncated(self._embed(text))
        mel_outputs, stop_tokens, alignments = \
            self.decoder.inference_truncated(encoder_outputs)
        mel_outputs, mel_outputs_postnet, alignments = \
            self._postnet_and_shape(mel_outputs, alignments)
        return mel_outputs, mel_outputs_postnet, alignments, stop_tokens