# TTS/network.py
# Tacotron network definitions: Encoder, MelDecoder, PostProcessingNet, Tacotron.
import random
from module import *
from text.symbols import symbols
class Encoder(nn.Module):
    """Character encoder: embedding -> prenet -> CBHG memory."""

    def __init__(self, embedding_size, hidden_size):
        """
        :param embedding_size: dimension of the character embedding
        :param hidden_size: base hidden dimension of the prenet/CBHG stack
        """
        super(Encoder, self).__init__()
        self.embedding_size = embedding_size
        self.embed = nn.Embedding(len(symbols), embedding_size)
        self.prenet = Prenet(embedding_size, hidden_size * 2, hidden_size)
        self.cbhg = CBHG(hidden_size)

    def forward(self, input_):
        # Swap to (batch, embedding, time) so the conv stack sees channels first.
        embedded = torch.transpose(self.embed(input_), 1, 2)
        bottleneck = self.prenet.forward(embedded)
        return self.cbhg.forward(bottleneck)
class MelDecoder(nn.Module):
    """
    Attention-based mel-spectrogram decoder.

    In training mode it consumes the ground-truth mel input with scheduled
    teacher forcing; in eval mode it runs free, feeding its own last output
    back, for at most ``max_iters`` decoder steps.
    """

    def __init__(self, num_mels, hidden_size, dec_out_per_step,
                 teacher_forcing_ratio, max_iters=200):
        """
        :param num_mels: number of mel filterbank channels
        :param hidden_size: base hidden dimension
        :param dec_out_per_step: mel frames produced per decoder step (r)
        :param teacher_forcing_ratio: probability of feeding the ground-truth
            frame back instead of the decoder's own output
        :param max_iters: cap on decoder steps at inference time. Fixes a bug:
            the eval branch previously read an undefined global ``max_iters``
            and raised ``NameError`` at inference.
        """
        super(MelDecoder, self).__init__()
        self.prenet = Prenet(num_mels, hidden_size * 2, hidden_size)
        self.attn_decoder = AttentionDecoder(hidden_size * 2, num_mels,
                                             dec_out_per_step)
        self.dec_out_per_step = dec_out_per_step
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.max_iters = max_iters

    def forward(self, decoder_input, memory):
        """
        :param decoder_input: mel input — ground truth while training,
            the [GO] frame at inference
        :param memory: encoder memory to attend over
        :return: decoded mel frames concatenated along the time axis
        """
        # Initialize hidden states of the attention/decoder GRU cells.
        attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.inithidden(
            decoder_input.size()[0])
        outputs = list()
        # Training phase
        if self.training:
            # Prenet over the whole ground-truth mel sequence at once.
            dec_input = self.prenet.forward(decoder_input)
            timesteps = dec_input.size()[2] // self.dec_out_per_step
            # [GO] Frame
            prev_output = dec_input[:, :, 0]
            for i in range(timesteps):
                prev_output, attn_hidden, gru1_hidden, gru2_hidden = \
                    self.attn_decoder.forward(prev_output, memory,
                                              attn_hidden=attn_hidden,
                                              gru1_hidden=gru1_hidden,
                                              gru2_hidden=gru2_hidden)
                outputs.append(prev_output)
                if random.random() < self.teacher_forcing_ratio:
                    # Teacher forcing: feed the ground-truth spectrum at the
                    # r-th position of the current step.
                    prev_output = dec_input[:, :, i * self.dec_out_per_step]
                else:
                    # Feed back the last frame the decoder just produced.
                    prev_output = prev_output[:, :, -1]
            # Concatenate all mel spectrogram chunks along the time axis.
            outputs = torch.cat(outputs, 2)
        else:
            # [GO] Frame; run free until the iteration cap.
            prev_output = decoder_input
            for _ in range(self.max_iters):
                # Prenet per step, since the next input is the step's output.
                prev_output = self.prenet.forward(prev_output)
                prev_output = prev_output[:, :, 0]
                prev_output, attn_hidden, gru1_hidden, gru2_hidden = \
                    self.attn_decoder.forward(prev_output, memory,
                                              attn_hidden=attn_hidden,
                                              gru1_hidden=gru1_hidden,
                                              gru2_hidden=gru2_hidden)
                outputs.append(prev_output)
                prev_output = prev_output[:, :, -1].unsqueeze(2)
            outputs = torch.cat(outputs, 2)
        return outputs
class PostProcessingNet(nn.Module):
    """Post-processing network: CBHG over decoded mels, then a linear
    projection up to the full linear-frequency spectrogram."""

    def __init__(self, num_mels, num_freq, hidden_size):
        """
        :param num_mels: number of mel channels coming from the decoder
        :param num_freq: number of linear-frequency bins to produce
        :param hidden_size: base hidden dimension of the post CBHG
        """
        super(PostProcessingNet, self).__init__()
        self.postcbhg = CBHG(hidden_size,
                             K=8,
                             projection_size=num_mels,
                             is_post=True)
        self.linear = SeqLinear(hidden_size * 2, num_freq)

    def forward(self, input_):
        cbhg_out = self.postcbhg.forward(input_)
        # Transpose so the linear layer maps the feature dimension.
        return self.linear.forward(torch.transpose(cbhg_out, 1, 2))
class Tacotron(nn.Module):
    """End-to-end Tacotron: text encoder -> attention mel decoder ->
    post-processing net producing the linear spectrogram."""

    def __init__(self, embedding_size, hidden_size, num_mels, num_freq,
                 dec_out_per_step, teacher_forcing_ratio):
        super(Tacotron, self).__init__()
        self.encoder = Encoder(embedding_size, hidden_size)
        self.decoder1 = MelDecoder(num_mels, hidden_size, dec_out_per_step,
                                   teacher_forcing_ratio)
        self.decoder2 = PostProcessingNet(num_mels, num_freq, hidden_size)

    def forward(self, characters, mel_input):
        # Encode characters, decode mels against the memory, then expand
        # the mels into a linear-frequency spectrogram.
        encoder_memory = self.encoder.forward(characters)
        mel_output = self.decoder1.forward(mel_input, encoder_memory)
        return mel_output, self.decoder2.forward(mel_output)