Set tacotron model parameters to adapt to common_layers.py - Prenet and Attention

pull/10/head
Eren Golge 2019-05-27 14:40:28 +02:00
parent 2586be7d33
commit ba492f43be
5 changed files with 158 additions and 46 deletions

layers/common_layers.py (new file)

@@ -0,0 +1,83 @@
+from math import sqrt
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+
+class Linear(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(Linear, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class LinearBN(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(LinearBN, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self.bn = nn.BatchNorm1d(out_features)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        out = self.linear_layer(x)
+        if len(out.shape) == 3:
+            out = out.permute(1, 2, 0)
+        out = self.bn(out)
+        if len(out.shape) == 3:
+            out = out.permute(2, 0, 1)
+        return out
+
+
+class Prenet(nn.Module):
+    def __init__(self,
+                 in_features,
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 out_features=[256, 256],
+                 bias=True):
+        super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
+        in_features = [in_features] + out_features[:-1]
+        if prenet_type == "bn":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "original":
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+
+    def forward(self, x):
+        for linear in self.layers:
+            if self.prenet_dropout:
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+            else:
+                x = F.relu(linear(x))
+        return x
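Note: a minimal usage sketch of the new Prenet (not part of the patch; the input shape and import path are assumptions for illustration):

import torch
from layers.common_layers import Prenet

x = torch.randn(8, 40, 80)  # assumed (batch, time, mel_dim) frame batch
# "original" prenet: Linear + ReLU, dropout applied when prenet_dropout=True
prenet_a = Prenet(80, prenet_type="original", out_features=[256, 128])
# "bn" prenet: Linear + BatchNorm1d; LinearBN permutes 3D inputs so the
# feature axis lands on BatchNorm1d's channel dimension
prenet_b = Prenet(80, prenet_type="bn", prenet_dropout=False,
                  out_features=[256, 128])
print(prenet_a(x).shape)  # torch.Size([8, 40, 128])
print(prenet_b(x).shape)  # torch.Size([8, 40, 128])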

layers/tacotron.py

@@ -2,38 +2,39 @@
 import torch
 from torch import nn
 from .attention import AttentionRNNCell
+from .common_layers import Prenet
 
-class Prenet(nn.Module):
-    r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
-    It creates as many layers as given by 'out_features'
+# class Prenet(nn.Module):
+#     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
+#     It creates as many layers as given by 'out_features'
 
-    Args:
-        in_features (int): size of the input vector
-        out_features (int or list): size of each output sample.
-            If it is a list, for each value, there is created a new layer.
-    """
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output sample.
+#             If it is a list, for each value, there is created a new layer.
+#     """
 
-    def __init__(self, in_features, out_features=[256, 128]):
-        super(Prenet, self).__init__()
-        in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            nn.Linear(in_size, out_size)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
-        self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.5)
-        # self.init_layers()
+#     def __init__(self, in_features, out_features=[256, 128]):
+#         super(Prenet, self).__init__()
+#         in_features = [in_features] + out_features[:-1]
+#         self.layers = nn.ModuleList([
+#             nn.Linear(in_size, out_size)
+#             for (in_size, out_size) in zip(in_features, out_features)
+#         ])
+#         self.relu = nn.ReLU()
+#         self.dropout = nn.Dropout(0.5)
+#         # self.init_layers()
 
-    def init_layers(self):
-        for layer in self.layers:
-            torch.nn.init.xavier_uniform_(
-                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
+#     def init_layers(self):
+#         for layer in self.layers:
+#             torch.nn.init.xavier_uniform_(
+#                 layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
-    def forward(self, inputs):
-        for linear in self.layers:
-            inputs = self.dropout(self.relu(linear(inputs)))
-        return inputs
+#     def forward(self, inputs):
+#         for linear in self.layers:
+#             inputs = self.dropout(self.relu(linear(inputs)))
+#         return inputs
 
 
 class BatchNormConv1d(nn.Module):
@@ -301,8 +302,9 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """
 
-    def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing, attn_norm, separate_stopnet):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
+                 attn_norm, prenet_type, prenet_dropout, forward_attn,
+                 trans_agent, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -312,7 +314,10 @@ class Decoder(nn.Module):
         self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
-            memory_dim * self.memory_size, out_features=[256, 128])
+            memory_dim * self.memory_size,
+            prenet_type,
+            prenet_dropout,
+            out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -385,23 +390,22 @@ class Decoder(nn.Module):
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
         return outputs, attentions, stop_tokens
 
-    def decode(self,
-               inputs,
-               t,
-               mask=None):
+    def decode(self, inputs, t, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
         attention_cat = torch.cat(
-            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1)
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
+            dim=1)
         self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
-            processed_memory, self.current_context_vec, self.attention_rnn_hidden,
-            inputs, attention_cat, mask, t)
+            processed_memory, self.current_context_vec,
+            self.attention_rnn_hidden, inputs, attention_cat, mask, t)
         del attention_cat
         self.attention_cum += self.attention
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1))
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
+                      -1))
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@@ -427,7 +431,8 @@ class Decoder(nn.Module):
             self.memory_input = torch.cat([
                 self.memory_input[:, self.r * self.memory_dim:].clone(),
                 new_memory
-            ], dim=-1)
+            ],
+                                          dim=-1)
         else:
             self.memory_input = new_memory
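Note: with this change the Decoder receives the prenet and attention options explicitly instead of hard-coding them. A sketch of the new constructor call (values mirror the defaults used elsewhere in this commit; dimensions are assumptions):

from layers.tacotron import Decoder

decoder = Decoder(
    in_features=256,      # encoder output size
    memory_dim=80,        # mel frame size
    r=5,
    memory_size=5,
    attn_windowing=False,
    attn_norm="sigmoid",
    prenet_type="original",
    prenet_dropout=True,
    forward_attn=False,
    trans_agent=False,
    location_attn=True,
    separate_stopnet=True)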

models/tacotron.py

@@ -9,23 +9,29 @@ from utils.generic_utils import sequence_mask
 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
+                 r=5,
                  linear_dim=1025,
                  mel_dim=80,
-                 r=5,
-                 padding_idx=None,
                  memory_size=5,
                  attn_win=False,
                  attn_norm="sigmoid",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 location_attn=True,
                  separate_stopnet=True):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
-                               attn_norm, separate_stopnet)
+                               attn_norm, prenet_type, prenet_dropout,
+                               forward_attn, trans_agent, location_attn,
+                               separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
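Note: a sketch of instantiating the updated Tacotron with the new keyword arguments (the num_chars value is arbitrary; other values are the defaults from this diff):

from models.tacotron import Tacotron

model = Tacotron(
    num_chars=61,
    r=5,
    linear_dim=1025,
    mel_dim=80,
    memory_size=5,
    attn_win=False,
    attn_norm="sigmoid",
    prenet_type="bn",  # switches the prenet to the BatchNorm variant
    prenet_dropout=True,
    forward_attn=False,
    trans_agent=False,
    location_attn=True,
    separate_stopnet=True)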

tests/test_tacotron_model.py

@@ -18,6 +18,11 @@ file_path = os.path.dirname(os.path.realpath(__file__))
 c = load_config(os.path.join(file_path, 'test_config.json'))
 
 
+def count_parameters(model):
+    r"""Count number of trainable parameters in a network"""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
@@ -33,13 +38,19 @@ class TacotronTrainTest(unittest.TestCase):
         stop_targets = stop_targets.view(input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze()
         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'],
-                         c.r, memory_size=c.memory_size).to(device)
+        model = Tacotron(
+            32,
+            linear_dim=c.audio['num_freq'],
+            mel_dim=c.audio['num_mels'],
+            r=c.r,
+            memory_size=c.memory_size).to(device)
         model.train()
+        print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
         for param, param_ref in zip(model.parameters(),
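Note: count_parameters sums only tensors with requires_grad set, so frozen parameters are excluded. A toy check (module and sizes are arbitrary):

import torch.nn as nn

layer = nn.Linear(10, 4)  # 10*4 weights + 4 biases = 44 parameters
assert sum(p.numel() for p in layer.parameters() if p.requires_grad) == 44
layer.bias.requires_grad = False  # freeze the bias
assert sum(p.numel() for p in layer.parameters() if p.requires_grad) == 40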

utils/generic_utils.py

@@ -251,9 +251,16 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            linear_dim=1025,
+            mel_dim=80,
+            memory_size=c.memory_size,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
-            memory_size=c.memory_size,
+            prenet_type=c.prenet_type,
+            prenet_dropout=c.prenet_dropout,
+            forward_attn=c.use_forward_attn,
+            trans_agent=c.transition_agent,
+            location_attn=c.location_attn,
             separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
         model = MyModel(