From ba492f43be211cedce9d967caaa3857efd89c42d Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 27 May 2019 14:40:28 +0200
Subject: [PATCH] Set tacotron model parameters to adapt to common_layers.py - Prenet and Attention

---
 layers/common_layers.py      | 83 ++++++++++++++++++++++++++++++++++++
 layers/tacotron.py           | 81 ++++++++++++++++++-----------------
 models/tacotron.py           | 14 ++++--
 tests/test_tacotron_model.py | 17 ++++++--
 utils/generic_utils.py       |  9 +++-
 5 files changed, 158 insertions(+), 46 deletions(-)
 create mode 100644 layers/common_layers.py

diff --git a/layers/common_layers.py b/layers/common_layers.py
new file mode 100644
index 00000000..c5704f62
--- /dev/null
+++ b/layers/common_layers.py
@@ -0,0 +1,83 @@
+from math import sqrt
+import torch
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+
+class Linear(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(Linear, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        return self.linear_layer(x)
+
+
+class LinearBN(nn.Module):
+    def __init__(self,
+                 in_features,
+                 out_features,
+                 bias=True,
+                 init_gain='linear'):
+        super(LinearBN, self).__init__()
+        self.linear_layer = torch.nn.Linear(
+            in_features, out_features, bias=bias)
+        self.bn = nn.BatchNorm1d(out_features)
+        self._init_w(init_gain)
+
+    def _init_w(self, init_gain):
+        torch.nn.init.xavier_uniform_(
+            self.linear_layer.weight,
+            gain=torch.nn.init.calculate_gain(init_gain))
+
+    def forward(self, x):
+        out = self.linear_layer(x)
+        if len(out.shape) == 3:
+            out = out.permute(1, 2, 0)
+        out = self.bn(out)
+        if len(out.shape) == 3:
+            out = out.permute(2, 0, 1)
+        return out
+
+
+class Prenet(nn.Module):
+    def __init__(self,
+                 in_features,
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 out_features=[256, 256],
+                 bias=True):
+        super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
+        self.prenet_dropout = prenet_dropout
+        in_features = [in_features] + out_features[:-1]
+        if prenet_type == "bn":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "original":
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=bias)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+
+    def forward(self, x):
+        for linear in self.layers:
+            if self.prenet_dropout:
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
+            else:
+                x = F.relu(linear(x))
+        return x
\ No newline at end of file
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 931388bf..690407b7 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -2,38 +2,39 @@
 import torch
 from torch import nn
 from .attention import AttentionRNNCell
+from .common_layers import Prenet
 
 
-class Prenet(nn.Module):
-    r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
-    It creates as many layers as given by 'out_features'
+# class Prenet(nn.Module):
+#     r""" Prenet as explained at https://arxiv.org/abs/1703.10135.
+#     It creates as many layers as given by 'out_features'
 
-    Args:
-        in_features (int): size of the input vector
-        out_features (int or list): size of each output sample.
-            If it is a list, for each value, there is created a new layer.
-    """
+#     Args:
+#         in_features (int): size of the input vector
+#         out_features (int or list): size of each output sample.
+#             If it is a list, for each value, there is created a new layer.
+#     """
 
-    def __init__(self, in_features, out_features=[256, 128]):
-        super(Prenet, self).__init__()
-        in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            nn.Linear(in_size, out_size)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
-        self.relu = nn.ReLU()
-        self.dropout = nn.Dropout(0.5)
-        # self.init_layers()
+#     def __init__(self, in_features, out_features=[256, 128]):
+#         super(Prenet, self).__init__()
+#         in_features = [in_features] + out_features[:-1]
+#         self.layers = nn.ModuleList([
+#             nn.Linear(in_size, out_size)
+#             for (in_size, out_size) in zip(in_features, out_features)
+#         ])
+#         self.relu = nn.ReLU()
+#         self.dropout = nn.Dropout(0.5)
+#         # self.init_layers()
 
-    def init_layers(self):
-        for layer in self.layers:
-            torch.nn.init.xavier_uniform_(
-                layer.weight, gain=torch.nn.init.calculate_gain('relu'))
+#     def init_layers(self):
+#         for layer in self.layers:
+#             torch.nn.init.xavier_uniform_(
+#                 layer.weight, gain=torch.nn.init.calculate_gain('relu'))
 
-    def forward(self, inputs):
-        for linear in self.layers:
-            inputs = self.dropout(self.relu(linear(inputs)))
-        return inputs
+#     def forward(self, inputs):
+#         for linear in self.layers:
+#             inputs = self.dropout(self.relu(linear(inputs)))
+#         return inputs
 
 
 class BatchNormConv1d(nn.Module):
@@ -301,8 +302,9 @@ class Decoder(nn.Module):
         memory_size (int): size of the past window. if <= 0 memory_size = r
     """
 
-    def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing, attn_norm, separate_stopnet):
+    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
+                 attn_norm, prenet_type, prenet_dropout, forward_attn,
+                 trans_agent, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -312,7 +314,10 @@ class Decoder(nn.Module):
         self.separate_stopnet = separate_stopnet
         # memory -> |Prenet| -> processed_memory
         self.prenet = Prenet(
-            memory_dim * self.memory_size, out_features=[256, 128])
+            memory_dim * self.memory_size,
+            prenet_type,
+            prenet_dropout,
+            out_features=[256, 128])
         # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
         self.attention_rnn = AttentionRNNCell(
             out_dim=128,
@@ -385,23 +390,22 @@ class Decoder(nn.Module):
         stop_tokens = torch.stack(stop_tokens).transpose(0, 1).squeeze(-1)
         return outputs, attentions, stop_tokens
 
-    def decode(self,
-               inputs,
-               t,
-               mask=None):
+    def decode(self, inputs, t, mask=None):
         # Prenet
         processed_memory = self.prenet(self.memory_input)
         # Attention RNN
         attention_cat = torch.cat(
-            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)), dim=1)
+            (self.attention.unsqueeze(1), self.attention_cum.unsqueeze(1)),
+            dim=1)
         self.attention_rnn_hidden, self.current_context_vec, self.attention = self.attention_rnn(
-            processed_memory, self.current_context_vec, self.attention_rnn_hidden,
-            inputs, attention_cat, mask, t)
+            processed_memory, self.current_context_vec,
+            self.attention_rnn_hidden, inputs, attention_cat, mask, t)
         del attention_cat
         self.attention_cum += self.attention
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
-            torch.cat((self.attention_rnn_hidden, self.current_context_vec), -1))
+            torch.cat((self.attention_rnn_hidden, self.current_context_vec),
+                      -1))
         # Pass through the decoder RNNs
         for idx in range(len(self.decoder_rnns)):
             self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
@@ -427,7 +431,8 @@ class Decoder(nn.Module):
             self.memory_input = torch.cat([
                 self.memory_input[:, self.r * self.memory_dim:].clone(),
                 new_memory
-            ], dim=-1)
+            ],
+                                          dim=-1)
         else:
             self.memory_input = new_memory
 
diff --git a/models/tacotron.py b/models/tacotron.py
index 4243a039..362bf8b5 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -9,23 +9,29 @@ from utils.generic_utils import sequence_mask
 class Tacotron(nn.Module):
     def __init__(self,
                  num_chars,
+                 r=5,
                  linear_dim=1025,
                  mel_dim=80,
-                 r=5,
-                 padding_idx=None,
                  memory_size=5,
                  attn_win=False,
                  attn_norm="sigmoid",
+                 prenet_type="original",
+                 prenet_dropout=True,
+                 forward_attn=False,
+                 trans_agent=False,
+                 location_attn=True,
                  separate_stopnet=True):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
+        self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
-                               attn_norm, separate_stopnet)
+                               attn_norm, prenet_type, prenet_dropout,
+                               forward_attn, trans_agent, location_attn,
+                               separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index e98aa5fb..e0580107 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -18,6 +18,11 @@ file_path = os.path.dirname(os.path.realpath(__file__))
 c = load_config(os.path.join(file_path, 'test_config.json'))
 
 
+def count_parameters(model):
+    r"""Count number of trainable parameters in a network"""
+    return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
 class TacotronTrainTest(unittest.TestCase):
     def test_train_step(self):
         input = torch.randint(0, 24, (8, 128)).long().to(device)
@@ -33,13 +38,19 @@ class TacotronTrainTest(unittest.TestCase):
 
         stop_targets = stop_targets.view(input.shape[0],
                                          stop_targets.size(1) // c.r, -1)
-        stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze()
+        stop_targets = (stop_targets.sum(2) >
+                        0.0).unsqueeze(2).float().squeeze()
         criterion = L1LossMasked().to(device)
         criterion_st = nn.BCEWithLogitsLoss().to(device)
-        model = Tacotron(32, c.audio['num_freq'], c.audio['num_mels'],
-                         c.r, memory_size=c.memory_size).to(device)
+        model = Tacotron(
+            32,
+            linear_dim=c.audio['num_freq'],
+            mel_dim=c.audio['num_mels'],
+            r=c.r,
+            memory_size=c.memory_size).to(device)
         model.train()
+        print(" > Num parameters for Tacotron model:%s"%(count_parameters(model)))
         model_ref = copy.deepcopy(model)
         count = 0
         for param, param_ref in zip(model.parameters(),
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 1158ea06..77fb4dc2 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -251,9 +251,16 @@ def setup_model(num_chars, c):
         model = MyModel(
             num_chars=num_chars,
             r=c.r,
+            linear_dim=1025,
+            mel_dim=80,
+            memory_size=c.memory_size,
             attn_win=c.windowing,
             attn_norm=c.attention_norm,
-            memory_size=c.memory_size,
+            prenet_type=c.prenet_type,
+            prenet_dropout=c.prenet_dropout,
+            forward_attn=c.use_forward_attn,
+            trans_agent=c.transition_agent,
+            location_attn=c.location_attn,
             separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
         model = MyModel(
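Usage note: a minimal sketch of how the new Prenet in layers/common_layers.py can be
exercised on its own, assuming the patch above is applied and PyTorch is available.
The 400-dim input mirrors the Decoder wiring (memory_dim * memory_size = 80 * 5);
the batch size and tensor values are illustrative only.

    import torch
    from layers.common_layers import Prenet

    # prenet_type="bn" selects the LinearBN (batch-norm) variant introduced by
    # this patch; prenet_type="original" keeps plain Linear layers with dropout.
    prenet = Prenet(400,
                    prenet_type="bn",
                    prenet_dropout=True,
                    out_features=[256, 128])
    x = torch.rand(4, 400)  # (batch, memory_dim * memory_size)
    y = prenet(x)           # -> (4, 128)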