# Mirror of https://github.com/coqui-ai/TTS.git
import random

# `torch` and `nn` are used directly in this file, so import them explicitly
# instead of relying on the wildcard import below to bring them in.
import torch
import torch.nn as nn

# Provides Prenet, CBHG, AttentionDecoder and SeqLinear used below.
from module import *
from text.symbols import symbols


class Encoder(nn.Module):
    """
    Encoder
    """

    def __init__(self, embedding_size, hidden_size):
        """
        :param embedding_size: dimension of the character embedding
        :param hidden_size: dimension of the hidden representations
        """
        super(Encoder, self).__init__()
        self.embedding_size = embedding_size
        self.embed = nn.Embedding(len(symbols), embedding_size)
        self.prenet = Prenet(embedding_size, hidden_size * 2, hidden_size)
        self.cbhg = CBHG(hidden_size)

    def forward(self, input_):
        # Embed the character ids and move the embedding dimension to axis 1,
        # so the prenet and CBHG see [batch, channels, time].
        input_ = torch.transpose(self.embed(input_), 1, 2)
        prenet = self.prenet(input_)
        memory = self.cbhg(prenet)

        return memory


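# Usage sketch (not part of the original file; the sizes below are assumptions,
# the real values come from the project's hyperparameters):
#
#   encoder = Encoder(embedding_size=256, hidden_size=128)
#   chars = torch.randint(0, len(symbols), (2, 50))  # [batch, text_length] ids
#   memory = encoder(chars)                          # attention memory for the decoder

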
class MelDecoder(nn.Module):
    """
    Decoder
    """

    def __init__(self, num_mels, hidden_size, dec_out_per_step,
                 teacher_forcing_ratio):
        super(MelDecoder, self).__init__()
        self.prenet = Prenet(num_mels, hidden_size * 2, hidden_size)
        self.attn_decoder = AttentionDecoder(hidden_size * 2, num_mels,
                                             dec_out_per_step)
        self.dec_out_per_step = dec_out_per_step
        self.teacher_forcing_ratio = teacher_forcing_ratio

    def forward(self, decoder_input, memory, max_iters=200):
        # Initialize the hidden states of the attention and decoder GRU cells.
        attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder.inithidden(
            decoder_input.size(0))
        outputs = list()

        # Training phase: run the prenet over the whole target sequence and
        # step through it dec_out_per_step frames at a time.
        if self.training:
            # Prenet
            dec_input = self.prenet(decoder_input)
            timesteps = dec_input.size(2) // self.dec_out_per_step

            # [GO] frame
            prev_output = dec_input[:, :, 0]

            for i in range(timesteps):
                prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder(
                    prev_output, memory,
                    attn_hidden=attn_hidden,
                    gru1_hidden=gru1_hidden,
                    gru2_hidden=gru2_hidden)

                outputs.append(prev_output)

                if random.random() < self.teacher_forcing_ratio:
                    # Teacher forcing: feed the ground-truth spectrum at the
                    # r-th position back into the decoder.
                    prev_output = dec_input[:, :, i * self.dec_out_per_step]
                else:
                    # Otherwise feed back the last frame the decoder produced.
                    prev_output = prev_output[:, :, -1]

            # Concatenate all mel spectrogram chunks along the time axis.
            outputs = torch.cat(outputs, 2)

        # Inference phase: start from the [GO] frame and feed the decoder its
        # own predictions. `max_iters` was undefined in this file; exposing it
        # as a parameter with a default cap of 200 is an assumption.
        else:
            # [GO] frame
            prev_output = decoder_input

            for i in range(max_iters):
                prev_output = self.prenet(prev_output)
                prev_output = prev_output[:, :, 0]
                prev_output, attn_hidden, gru1_hidden, gru2_hidden = self.attn_decoder(
                    prev_output, memory,
                    attn_hidden=attn_hidden,
                    gru1_hidden=gru1_hidden,
                    gru2_hidden=gru2_hidden)
                outputs.append(prev_output)
                prev_output = prev_output[:, :, -1].unsqueeze(2)

            outputs = torch.cat(outputs, 2)

        return outputs


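# Usage sketch (not part of the original file; sizes are assumptions, and
# dec_out_per_step plays the role of the reduction factor "r"):
#
#   decoder = MelDecoder(num_mels=80, hidden_size=128, dec_out_per_step=5,
#                        teacher_forcing_ratio=1.0)
#   decoder.train()
#   mel_out = decoder(mel_target, memory)  # driven by ground-truth frames
#   decoder.eval()
#   mel_out = decoder(go_frame, memory)    # autoregressive, capped by max_iters

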
class PostProcessingNet(nn.Module):
    """
    Post-processing Network
    """

    def __init__(self, num_mels, num_freq, hidden_size):
        super(PostProcessingNet, self).__init__()
        self.postcbhg = CBHG(hidden_size,
                             K=8,
                             projection_size=num_mels,
                             is_post=True)
        self.linear = SeqLinear(hidden_size * 2,
                                num_freq)

    def forward(self, input_):
        # Refine the mel prediction with the post CBHG, then project it to the
        # linear-spectrogram frequency bins.
        out = self.postcbhg(input_)
        out = self.linear(torch.transpose(out, 1, 2))

        return out


class Tacotron(nn.Module):
    """
    End-to-end Tacotron Network
    """

    def __init__(self, embedding_size, hidden_size, num_mels, num_freq,
                 dec_out_per_step, teacher_forcing_ratio):
        super(Tacotron, self).__init__()
        self.encoder = Encoder(embedding_size, hidden_size)
        self.decoder1 = MelDecoder(num_mels, hidden_size, dec_out_per_step,
                                   teacher_forcing_ratio)
        self.decoder2 = PostProcessingNet(num_mels, num_freq, hidden_size)

    def forward(self, characters, mel_input):
        memory = self.encoder(characters)
        mel_output = self.decoder1(mel_input, memory)
        linear_output = self.decoder2(mel_output)

        return mel_output, linear_output
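

# Minimal smoke test (a sketch, not part of the original file). All
# hyperparameter values below are assumptions, and Prenet, CBHG,
# AttentionDecoder and SeqLinear must be provided by `module` for this to run.
if __name__ == "__main__":
    model = Tacotron(embedding_size=256, hidden_size=128, num_mels=80,
                     num_freq=1024, dec_out_per_step=5,
                     teacher_forcing_ratio=1.0)
    model.train()

    batch_size, text_len, mel_len = 2, 50, 100
    characters = torch.randint(0, len(symbols), (batch_size, text_len))
    mel_input = torch.randn(batch_size, 80, mel_len)  # [batch, num_mels, frames]

    mel_output, linear_output = model(characters, mel_input)
    print(mel_output.size(), linear_output.size())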