Set up the model externally with setup_model based on the selected model. Make forward attention and prenet type configurable in config.json

pull/10/head
Eren Golge 2019-04-05 17:49:18 +02:00
parent 043e49f367
commit 961af0f5cd
6 changed files with 78 additions and 28 deletions

View File

@ -38,8 +38,10 @@
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,
@ -47,9 +49,9 @@
"wd": 0.000001, // Weight decay weight.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
"print_step": 100, // Number of steps to log traning on console.
"print_step": 100, // Number of steps to log traning on console.
"tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"batch_group_size": 8, //Number of batches to shuffle after bucketing.
"batch_group_size": 8, // Number of batches to shuffle after bucketing.
"run_eval": true,
"test_delay_epochs": 2, //Until attention is aligned, testing only wastes computation time.

View File

@ -38,8 +38,10 @@
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,

View File

@ -53,17 +53,27 @@ class LinearBN(nn.Module):
class Prenet(nn.Module):
def __init__(self, in_features, out_features=[256, 256]):
def __init__(self, in_features, prenet_type, out_features=[256, 256]):
super(Prenet, self).__init__()
self.prenet_type = prenet_type
in_features = [in_features] + out_features[:-1]
self.layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_features, out_features)
])
if prenet_type == "original":
self.layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_features, out_features)
])
elif prenet_type == "bn":
self.layers = nn.ModuleList(
[Linear(in_size, out_size, bias=False)
for (in_size, out_size) in zip(in_features, out_features)
])
def forward(self, x):
for linear in self.layers:
x = F.relu(linear(x))
if self.prenet_type == "original":
x = F.relu(linear(x))
elif self.prenet_type == "bn":
x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
return x
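
A small usage sketch for the two prenet variants (assuming the Prenet class above and PyTorch; the batch size and r = 5 are illustrative):

import torch

# 80 mel channels * r frames per step, matching the decoder's prenet input below
frames = torch.randn(32, 80 * 5)

for variant in ("original", "bn"):
    prenet = Prenet(80 * 5, prenet_type=variant, out_features=[256, 256])
    out = prenet(frames)
    assert out.shape == (32, 256)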
@ -112,7 +122,7 @@ class LocationLayer(nn.Module):
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size,
windowing, norm):
windowing, norm, forward_attn):
super(Attention, self).__init__()
self.query_layer = Linear(
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@ -126,11 +136,21 @@ class Attention(nn.Module):
self.windowing = windowing
self.win_idx = None
self.norm = norm
self.forward_attn = forward_attn
def init_win_idx(self):
self.win_idx = -1
self.win_back = 1
self.win_front = 3
def init_forward_attn_state(self, inputs):
"""
Init forward attention states
"""
B = inputs.shape[0]
T = inputs.shape[1]
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
def get_attention(self, query, processed_inputs, attention_cat):
processed_query = self.query_layer(query.unsqueeze(1))
@ -170,9 +190,19 @@ class Attention(nn.Module):
attention).sum(dim=1).unsqueeze(1)
else:
raise RuntimeError("Unknown value for attention norm type")
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)
return context, alignment
if self.forward_attn:
# forward attention
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
# compute context
context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
context = context.squeeze(1)
return context, alpha_norm, alignment
else:
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)
return context, alignment, alignment
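
The forward-attention recursion above can be sketched in isolation (a toy example with the transition probability u fixed at 0.5, mirroring the init and update steps; shapes and values are illustrative):

import torch
import torch.nn.functional as F

B, T = 2, 6                                  # batch size, encoder timesteps
# init: all probability mass on the first encoder step
alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T - 1)], dim=1)
u = 0.5 * torch.ones(B, 1)

for _ in range(4):                           # a few decoder steps
    alignment = torch.softmax(torch.randn(B, T), dim=1)   # stand-in for the attention scores
    prev_alpha = F.pad(alpha[:, :-1], (1, 0))             # alpha shifted right by one encoder step
    alpha = ((1 - u) * alpha + u * prev_alpha + 1e-7) * alignment
    alpha_norm = alpha / alpha.sum(dim=1, keepdim=True)   # returned as the attention weights
    assert torch.allclose(alpha_norm.sum(dim=1), torch.ones(B))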
class Postnet(nn.Module):
@ -242,7 +272,7 @@ class Encoder(nn.Module):
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
super(Decoder, self).__init__()
self.mel_channels = inputs_dim
self.r = r
@ -255,14 +285,14 @@ class Decoder(nn.Module):
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(self.mel_channels * r,
self.prenet = Prenet(self.mel_channels * r, prenet_type,
[self.prenet_dim, self.prenet_dim])
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.attention_rnn_dim)
self.attention_layer = Attention(self.attention_rnn_dim, in_features,
128, 32, 31, attn_win, attn_norm)
128, 32, 31, attn_win, attn_norm, forward_attn)
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
self.decoder_rnn_dim, 1)
@ -340,11 +370,11 @@ class Decoder(nn.Module):
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
dim=1)
self.context, self.attention_weights = self.attention_layer(
self.context, self.attention_weights, alignments = self.attention_layer(
self.attention_hidden, self.inputs, self.processed_inputs,
attention_cat, self.mask)
self.attention_weights_cum += self.attention_weights
self.attention_weights_cum += alignments
memory = torch.cat(
(self.attention_hidden, self.context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@ -372,6 +402,8 @@ class Decoder(nn.Module):
memories = self.prenet(memories)
self._init_states(inputs, mask=mask)
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
@ -392,6 +424,9 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=None)
self.attention_layer.init_win_idx()
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [True, False, False]
stop_count = 0
@ -433,8 +468,9 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=None, keep_states=True)
self.attention_layer.init_win_idx()
if self.attention_layer.forward_attn:
self.attention_layer.init_forward_attn_state(inputs)
outputs, gate_outputs, alignments, t = [], [], [], 0
stop_flags = [False, False]
stop_flags = [False, False, False]
stop_count = 0
while True:
memory = self.prenet(self.memory_truncated)
@ -454,6 +490,7 @@ class Decoder(nn.Module):
elif len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
break
self.memory_truncated = mel_output
t += 1
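
A construction sketch showing how the new arguments flow into the prenet and attention layer (assuming the classes in this file; in_features and inputs_dim match the Tacotron2 defaults below, r is illustrative):

# build a decoder with forward attention and the "bn" prenet variant enabled
decoder = Decoder(in_features=512, inputs_dim=80, r=5, attn_win=False,
                  attn_norm="softmax", prenet_type="bn", forward_attn=True)
assert decoder.prenet.prenet_type == "bn"
assert decoder.attention_layer.forward_attn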

View File

@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
super(Tacotron2, self).__init__()
self.n_mel_channels = 80
self.n_frames_per_step = r
@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
self.postnet = Postnet(self.n_mel_channels)
def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
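
A direct instantiation sketch with the extended signature (num_chars and r are placeholders here; train.py derives them from the symbol set and config):

# forward attention on, original prenet, no attention windowing
model = Tacotron2(num_chars=61, r=5, attn_win=False, attn_norm="softmax",
                  prenet_type="original", forward_attn=True)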

View File

@ -23,7 +23,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
load_config, lr_decay,
remove_experiment_folder, save_best_model,
save_checkpoint, sequence_mask, weight_decay,
set_init_dict, copy_config_file)
set_init_dict, copy_config_file, setup_model)
from utils.logger import Logger
from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols
@ -375,7 +375,7 @@ def main(args):
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
model = setup_model(num_chars, c)
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
@ -528,9 +528,6 @@ if __name__ == '__main__':
# Conditional imports
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, c.dataset.lower())
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('models.'+c.model.lower())
MyModel = getattr(MyModel, c.model)
# Audio processor
ap = AudioProcessor(**c.audio)

View File

@ -8,6 +8,7 @@ import datetime
import json
import torch
import subprocess
import importlib
import numpy as np
from collections import OrderedDict
from torch.autograd import Variable
@ -236,4 +237,15 @@ def set_init_dict(model_dict, checkpoint, c):
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
return model_dict
return model_dict
def setup_model(num_chars, c):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('models.'+c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() == "tacotron":
model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
return model
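
A usage sketch for setup_model (a hypothetical inline config built with SimpleNamespace, mirroring the keys each branch reads; num_chars is a placeholder, and train.py passes the loaded config instead):

from types import SimpleNamespace
from utils.generic_utils import setup_model

# keys read by setup_model; memory_size is only consumed for Tacotron,
# prenet_type / use_forward_attn only for Tacotron2
c = SimpleNamespace(model="Tacotron2", r=5, attention_norm="softmax",
                    prenet_type="original", use_forward_attn=False,
                    memory_size=5)
model = setup_model(num_chars=61, c=c)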