diff --git a/config.json b/config.json
index 46ea7865..e6cc8a53 100644
--- a/config.json
+++ b/config.json
@@ -38,8 +38,10 @@
     "lr_decay": false,          // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,       // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,         // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,           // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+    "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.

     "batch_size": 32,           // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
@@ -47,9 +49,9 @@
     "wd": 0.000001,             // Weight decay weight.
     "checkpoint": true,         // If true, it saves checkpoints per "save_step"
     "save_step": 1000,          // Number of training steps expected to save traning stats and checkpoints.
-    "print_step": 100,          // Number of steps to log traning on console.
+    "print_step": 100,          // Number of steps to log traning on console.
     "tb_model_param_stats": true,     // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
-    "batch_group_size": 8,      //Number of batches to shuffle after bucketing.
+    "batch_group_size": 8,      // Number of batches to shuffle after bucketing.

     "run_eval": true,
     "test_delay_epochs": 2,     //Until attention is aligned, testing only wastes computation time.
diff --git a/config_cluster.json b/config_cluster.json
index b61dfedc..0af1d556 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -38,8 +38,10 @@
     "lr_decay": false,          // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,       // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,         // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,           // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+    "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.

     "batch_size": 32,           // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 6b413d10..4e41f366 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -53,17 +53,27 @@ class LinearBN(nn.Module):


 class Prenet(nn.Module):
-    def __init__(self, in_features, out_features=[256, 256]):
+    def __init__(self, in_features, prenet_type, out_features=[256, 256]):
         super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
         in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            LinearBN(in_size, out_size, bias=False)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
+        if prenet_type == "original":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=False)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "bn":
+            self.layers = nn.ModuleList(
+                [Linear(in_size, out_size, bias=False)
+                 for (in_size, out_size) in zip(in_features, out_features)
+                 ])

     def forward(self, x):
         for linear in self.layers:
-            x = F.relu(linear(x))
+            if self.prenet_type == "original":
+                x = F.relu(linear(x))
+            elif self.prenet_type == "bn":
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
         return x


@@ -112,7 +122,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm):
+                 windowing, norm, forward_attn):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -126,11 +136,21 @@ class Attention(nn.Module):
         self.windowing = windowing
         self.win_idx = None
         self.norm = norm
+        self.forward_attn = forward_attn

     def init_win_idx(self):
         self.win_idx = -1
         self.win_back = 1
         self.win_front = 3
+
+    def init_forward_attn_state(self, inputs):
+        """
+        Init forward attention states
+        """
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
+        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

     def get_attention(self, query, processed_inputs, attention_cat):
         processed_query = self.query_layer(query.unsqueeze(1))
@@ -170,9 +190,19 @@ class Attention(nn.Module):
                 attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
-        context = torch.bmm(alignment.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        return context, alignment
+        if self.forward_attn:
+            # forward attention
+            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+            # compute context
+            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alpha_norm, alignment
+        else:
+            context = torch.bmm(alignment.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alignment, alignment


 class Postnet(nn.Module):
@@ -242,7 +272,7 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -255,14 +285,14 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

-        self.prenet = Prenet(self.mel_channels * r,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type,
                              [self.prenet_dim, self.prenet_dim])

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -340,11 +370,11 @@ class Decoder(nn.Module):
         attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                    self.attention_weights_cum.unsqueeze(1)),
                                   dim=1)
-        self.context, self.attention_weights = self.attention_layer(
+        self.context, self.attention_weights, alignments = self.attention_layer(
             self.attention_hidden, self.inputs, self.processed_inputs,
             attention_cat, self.mask)

-        self.attention_weights_cum += self.attention_weights
+        self.attention_weights_cum += alignments
         memory = torch.cat(
             (self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -372,6 +402,8 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -392,6 +424,9 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None)

         self.attention_layer.init_win_idx()
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0
@@ -433,8 +468,9 @@ class Decoder(nn.Module):

         self._init_states(inputs, mask=None, keep_states=True)
         self.attention_layer.init_win_idx()
+        self.attention_layer.init_forward_attn_state(inputs)
         outputs, gate_outputs, alignments, t = [], [], [], 0
-        stop_flags = [False, False]
+        stop_flags = [False, False, False]
         stop_count = 0
         while True:
             memory = self.prenet(self.memory_truncated)
@@ -454,6 +490,7 @@ class Decoder(nn.Module):
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break

+            self.memory_truncated = mel_output
             t += 1

diff --git a/models/tacotron2.py b/models/tacotron2.py
index 671e6bb8..dadc4a24 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
diff --git a/train.py b/train.py
index 97eec378..8f818ac2 100644
--- a/train.py
+++ b/train.py
@@ -23,7 +23,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay, remove_experiment_folder,
                                  save_best_model, save_checkpoint,
                                  sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file)
+                                 set_init_dict, copy_config_file, setup_model)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
+    model = setup_model(num_chars, c)

     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

@@ -528,9 +528,6 @@ if __name__ == '__main__':
     # Conditional imports
     preprocessor = importlib.import_module('datasets.preprocess')
     preprocessor = getattr(preprocessor, c.dataset.lower())
-    print(" > Using model: {}".format(c.model))
-    MyModel = importlib.import_module('models.'+c.model.lower())
-    MyModel = getattr(MyModel, c.model)

     # Audio processor
     ap = AudioProcessor(**c.audio)
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 1a791290..7eec4e9c 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -8,6 +8,7 @@ import datetime
 import json
 import torch
 import subprocess
+import importlib
 import numpy as np
 from collections import OrderedDict
 from torch.autograd import Variable
@@ -236,4 +237,15 @@ def set_init_dict(model_dict, checkpoint, c):
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
     print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
-    return model_dict
\ No newline at end of file
+    return model_dict
+
+
+def setup_model(num_chars, c):
+    print(" > Using model: {}".format(c.model))
+    MyModel = importlib.import_module('models.'+c.model.lower())
+    MyModel = getattr(MyModel, c.model)
+    if c.model.lower() == "tacotron":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
+    elif c.model.lower() == "tacotron2":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
+    return model
\ No newline at end of file
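
Side note on the forward attention update added to Attention in layers/tacotron2.py: a minimal standalone sketch of the same recursion, not part of the patch. The helper name forward_attention_step and the toy shapes are illustrative only; alpha, u, prev_alpha, and alignment mirror the patch. In the hunks shown, u is only initialized to 0.5 and never updated, so it acts as a fixed transition probability, and the unnormalized alpha is carried forward as state while the normalized weights are used only to build the context vector.

import torch
import torch.nn.functional as F

def forward_attention_step(alpha, u, alignment):
    # One forward-attention update, mirroring the patch:
    #   alpha_t = ((1 - u) * alpha_{t-1} + u * shift(alpha_{t-1}) + 1e-7) * alignment_t
    # where shift() moves attention mass one encoder step forward.
    prev_alpha = F.pad(alpha[:, :-1], (1, 0))              # alpha shifted right by one position
    alpha = ((1 - u) * alpha + u * prev_alpha + 1e-7) * alignment
    alpha_norm = alpha / alpha.sum(dim=1, keepdim=True)    # normalized weights used for the context vector
    return alpha, alpha_norm

# Toy usage: batch of 2, 6 encoder steps.
B, T = 2, 6
alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T - 1)], dim=1)  # init: all mass on the first frame
u = 0.5 * torch.ones(B, 1)                                           # fixed transition probability
alignment = torch.softmax(torch.randn(B, T), dim=1)                  # stand-in for the content-based scores
alpha, alpha_norm = forward_attention_step(alpha, u, alignment)      # carry unnormalized alpha to the next step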