mirror of https://github.com/coqui-ai/TTS.git
Set tacotron model parameters to adapt to common_layers.py - Prenet and Attention
parent
2586be7d33
commit
ba492f43be
@@ -0,0 +1,83 @@
+from math import sqrt
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+
+class Linear(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(Linear, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class LinearBN(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(LinearBN, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self.bn = nn.BatchNorm1d(out_features)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        out = self.linear_layer(x)
+        if len(out.shape) == 3:
+            out = out.permute(1, 2, 0)
+        out = self.bn(out)
+        if len(out.shape) == 3:
+            out = out.permute(2, 0, 1)
+        return out
+
+
+class Prenet(nn.Module):
+    def __init__(self,
+                 in_features,
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 out_features=[256, 256],
+                 bias=True):
+        super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
+        in_features = [in_features] + out_features[:-1]
+        if prenet_type == "bn":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "original":
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+
+    def forward(self, x):
+        for linear in self.layers:
+            if self.prenet_dropout:
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+            else:
+                x = F.relu(linear(x))
+        return x
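
For orientation, a minimal sketch of how the new Prenet above could be exercised with both prenet_type options. The import path and all sizes below are assumptions for illustration, not part of the commit:

    import torch
    from layers.common_layers import Prenet  # module path assumed

    prenet = Prenet(80 * 5, prenet_type="original", out_features=[256, 128])
    prenet_bn = Prenet(80 * 5, prenet_type="bn", prenet_dropout=False,
                       out_features=[256, 128])

    x = torch.randn(8, 80 * 5)  # B x (memory_dim * memory_size)
    y = prenet(x)               # dropout applies only while prenet.training is True
    assert y.shape == (8, 128)

Note that with prenet_dropout=False the "bn" variant relies on batch normalization alone, which is the design choice this layer exists to make switchable.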
@@ -2,38 +2,39 @@
 import torch
 from torch import nn
 from .attention import AttentionRNNCell
+from .common_layers import Prenet
 
 
-class Prenet(nn.Module):
-    r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
-    It creates as many layers as given by 'out_features'
+# class Prenet(nn.Module):
+#     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
+#     It creates as many layers as given by 'out_features'
 
-    Args:
-        in_features (int): size of the input vector
-        out_features (int or list): size of each output sample.
-            If it is a list, for each value, there is created a new layer.
-    """
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output sample.
+#             If it is a list, for each value, there is created a new layer.
+#     """
 
-    def __init__(self, in_features, out_features=[256, 128]):
-        super(Prenet, self).__init__()
-        in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            nn.Linear(in_size, out_size)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
-        self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.5)
-        # self.init_layers()
+#     def __init__(self, in_features, out_features=[256, 128]):
+#         super(Prenet, self).__init__()
+#         in_features = [in_features] + out_features[:-1]
+#         self.layers = nn.ModuleList([
+#             nn.Linear(in_size, out_size)
+#             for (in_size, out_size) in zip(in_features, out_features)
+#         ])
+#         self.relu = nn.ReLU()
+#         self.dropout = nn.Dropout(0.5)
+#         # self.init_layers()
 
-    def init_layers(self):
-        for layer in self.layers:
-            torch.nn.init.xavier_uniform_(
-                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
+#     def init_layers(self):
+#         for layer in self.layers:
+#             torch.nn.init.xavier_uniform_(
+#                 layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
-    def forward(self, inputs):
-        for linear in self.layers:
-            inputs = self.dropout(self.relu(linear(inputs)))
-        return inputs
+#     def forward(self, inputs):
+#         for linear in self.layers:
+#             inputs = self.dropout(self.relu(linear(inputs)))
+#         return inputs
 
 
 class BatchNormConv1d(nn.Module):
@@ -301,8 +302,9 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """
 
-    def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing, attn_norm, separate_stopnet):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
+                 attn_norm, prenet_type, prenet_dropout, forward_attn,
+                 trans_agent, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -312,7 +314,10 @@
         self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
-            memory_dim * self.memory_size, out_features=[256, 128])
+            memory_dim * self.memory_size,
+            prenet_type,
+            prenet_dropout,
+            out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -385,23 +390,22 @@
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
         return outputs, attentions, stop_tokens
 
-    def decode(self,
-               inputs,
-               t,
-               mask=None):
+    def decode(self, inputs, t, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
         attention_cat = torch.cat(
-            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1)
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
+            dim=1)
         self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
-            processed_memory, self.current_context_vec, self.attention_rnn_hidden,
-            inputs, attention_cat, mask, t)
+            processed_memory, self.current_context_vec,
+            self.attention_rnn_hidden, inputs, attention_cat, mask, t)
         del attention_cat
         self.attention_cum += self.attention
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1))
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
+                      -1))
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@@ -427,7 +431,8 @@
             self.memory_input = torch.cat([
                 self.memory_input[:, self.r * self.memory_dim:].clone(),
                 new_memory
-            ], dim=-1)
+            ],
+                                          dim=-1)
         else:
             self.memory_input = new_memory
 
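
The block above is the decoder's sliding memory queue, reformatted by the linter. A standalone sketch of the same update, with illustrative sizes (r, memory_dim, and memory_size values are assumptions, not from the diff):

    import torch

    r, memory_dim, memory_size = 2, 80, 5
    memory_input = torch.zeros(1, memory_size * memory_dim)  # B x (memory_size * memory_dim)
    new_memory = torch.randn(1, r * memory_dim)              # newest r decoder frames, flattened

    # Drop the oldest r frames and append the newest r frames; width is unchanged.
    memory_input = torch.cat(
        [memory_input[:, r * memory_dim:].clone(), new_memory], dim=-1)
    assert memory_input.shape[-1] == memory_size * memory_dim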
@@ -9,23 +9,29 @@ from utils.generic_utils import sequence_mask
 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
-                 r=5,
                  linear_dim=1025,
                  mel_dim=80,
+                 r=5,
-                 padding_idx=None,
                  memory_size=5,
                  attn_win=False,
                  attn_norm="sigmoid",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 location_attn=True,
                  separate_stopnet=True):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
-                               attn_norm, separate_stopnet)
+                               attn_norm, prenet_type, prenet_dropout,
+                               forward_attn, trans_agent, location_attn,
+                               separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
@@ -18,6 +18,11 @@ file_path = os.path.dirname(os.path.realpath(__file__))
 c = load_config(os.path.join(file_path, 'test_config.json'))
 
 
+def count_parameters(model):
+    r"""Count number of trainable parameters in a network"""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
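
A quick sanity check for the count_parameters helper added above (illustrative, not part of the commit):

    import torch.nn as nn

    layer = nn.Linear(4, 2)  # 4 * 2 weights + 2 biases = 10 trainable parameters
    assert count_parameters(layer) == 10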
@@ -33,13 +38,19 @@
 
         stop_targets = stop_targets.view(input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze()
 
         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'],
-                         c.r, memory_size=c.memory_size).to(device)
+        model = Tacotron(
+            32,
+            linear_dim=c.audio['num_freq'],
+            mel_dim=c.audio['num_mels'],
+            r=c.r,
+            memory_size=c.memory_size).to(device)
         model.train()
+        print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
         for param, param_ref in zip(model.parameters(),
@@ -251,9 +251,16 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            linear_dim=1025,
+            mel_dim=80,
-            memory_size=c.memory_size,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
+            memory_size=c.memory_size,
+            prenet_type=c.prenet_type,
+            prenet_dropout=c.prenet_dropout,
+            forward_attn=c.use_forward_attn,
+            trans_agent=c.transition_agent,
+            location_attn=c.location_attn,
             separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
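
setup_model now forwards the new prenet and attention options from the config object. A hypothetical stub showing the attributes a config would need after this change; the AttrDict helper and the values are illustrative only (the values mirror the Tacotron defaults above), not part of the commit:

    # Hypothetical config stub satisfying the setup_model call above.
    class AttrDict(dict):
        __getattr__ = dict.__getitem__

    c = AttrDict(
        model="tacotron",
        r=5,
        memory_size=5,
        windowing=False,
        attention_norm="sigmoid",
        prenet_type="original",
        prenet_dropout=True,
        use_forward_attn=False,
        transition_agent=False,
        location_attn=True,
        separate_stopnet=True)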