mirror of https://github.com/coqui-ai/TTS.git
Add setup_model to build the model externally based on model selection. Make forward attention and prenet type configurable in config.json
parent 043e49f367
commit 961af0f5cd
config.json
@@ -38,8 +38,10 @@
     "lr_decay": false,          // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,       // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,         // Enables attention windowing. Used only in eval mode.
-    "memory_size": 5,           // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "memory_size": 5,           // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
+    "prenet_type": "original",  // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": false,  // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.

     "batch_size": 32,           // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size": 16,
@@ -47,9 +49,9 @@
     "wd": 0.000001,             // Weight decay weight.
     "checkpoint": true,         // If true, it saves checkpoints per "save_step"
     "save_step": 1000,          // Number of training steps expected to save traning stats and checkpoints.
     "print_step": 100,          // Number of steps to log traning on console.
     "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
-    "batch_group_size": 8,      //Number of batches to shuffle after bucketing.
+    "batch_group_size": 8,      // Number of batches to shuffle after bucketing.

     "run_eval": true,
     "test_delay_epochs": 2,     //Until attention is aligned, testing only wastes computation time.
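For orientation, a minimal sketch (not part of the commit) of how the two new keys travel from config.json to the model; it assumes the repo's existing `load_config` helper (imported in train.py further down) returns the parsed config with attribute access:

```python
from utils.generic_utils import load_config

c = load_config("config.json")

# The two switches added by this commit (Tacotron2 only):
assert c.prenet_type in ("original", "bn")
assert isinstance(c.use_forward_attn, bool)

# Downstream (see setup_model at the end of this diff) they are forwarded as
# Tacotron2(..., prenet_type=c.prenet_type, forward_attn=c.use_forward_attn).
```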
@@ -53,17 +53,27 @@ class LinearBN(nn.Module):


 class Prenet(nn.Module):
-    def __init__(self, in_features, out_features=[256, 256]):
+    def __init__(self, in_features, prenet_type, out_features=[256, 256]):
         super(Prenet, self).__init__()
+        self.prenet_type = prenet_type
         in_features = [in_features] + out_features[:-1]
-        self.layers = nn.ModuleList([
-            LinearBN(in_size, out_size, bias=False)
-            for (in_size, out_size) in zip(in_features, out_features)
-        ])
+        if prenet_type == "original":
+            self.layers = nn.ModuleList([
+                LinearBN(in_size, out_size, bias=False)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])
+        elif prenet_type == "bn":
+            self.layers = nn.ModuleList([
+                Linear(in_size, out_size, bias=False)
+                for (in_size, out_size) in zip(in_features, out_features)
+            ])

     def forward(self, x):
         for linear in self.layers:
-            x = F.relu(linear(x))
+            if self.prenet_type == "original":
+                x = F.relu(linear(x))
+            elif self.prenet_type == "bn":
+                x = F.dropout(F.relu(linear(x)), p=0.5, training=self.training)
         return x

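As a rough standalone illustration (not part of the commit) of what the two branches above do, the sketch below approximates `LinearBN` with a plain linear layer followed by batch norm and `Linear` with a plain linear layer; as committed, `prenet_type="original"` keeps the batch-normalized layers with a bare ReLU, while `prenet_type="bn"` uses plain linear layers with ReLU plus dropout that is active only during training:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PrenetSketch(nn.Module):
    """Illustrative stand-in for the configurable Prenet in the diff above."""

    def __init__(self, in_features, prenet_type="original", out_features=(256, 256)):
        super().__init__()
        self.prenet_type = prenet_type
        sizes = [in_features] + list(out_features)
        if prenet_type == "original":
            # approximates LinearBN(in, out, bias=False): linear + batch norm
            self.layers = nn.ModuleList([
                nn.Sequential(nn.Linear(i, o, bias=False), nn.BatchNorm1d(o))
                for i, o in zip(sizes[:-1], sizes[1:])
            ])
        elif prenet_type == "bn":
            # approximates Linear(in, out, bias=False): plain linear layer
            self.layers = nn.ModuleList([
                nn.Linear(i, o, bias=False)
                for i, o in zip(sizes[:-1], sizes[1:])
            ])

    def forward(self, x):
        for layer in self.layers:
            if self.prenet_type == "original":
                x = F.relu(layer(x))
            else:  # "bn": dropout only applied while self.training is True
                x = F.dropout(F.relu(layer(x)), p=0.5, training=self.training)
        return x


x = torch.randn(4, 80 * 5)                          # r=5 stacked mel frames
print(PrenetSketch(80 * 5, "original")(x).shape)    # torch.Size([4, 256])
print(PrenetSketch(80 * 5, "bn")(x).shape)          # torch.Size([4, 256])
```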
@@ -112,7 +122,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm):
+                 windowing, norm, forward_attn):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -126,11 +136,21 @@ class Attention(nn.Module):
         self.windowing = windowing
         self.win_idx = None
         self.norm = norm
+        self.forward_attn = forward_attn

     def init_win_idx(self):
         self.win_idx = -1
         self.win_back = 1
         self.win_front = 3

+    def init_forward_attn_state(self, inputs):
+        """
+        Init forward attention states
+        """
+        B = inputs.shape[0]
+        T = inputs.shape[1]
+        self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1]], dim=1).to(inputs.device)
+        self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
+
     def get_attention(self, query, processed_inputs, attention_cat):
         processed_query = self.query_layer(query.unsqueeze(1))
@@ -170,9 +190,19 @@ class Attention(nn.Module):
                                  attention).sum(dim=1).unsqueeze(1)
         else:
             raise RuntimeError("Unknown value for attention norm type")
-        context = torch.bmm(alignment.unsqueeze(1), inputs)
-        context = context.squeeze(1)
-        return context, alignment
+        if self.forward_attn:
+            # forward attention
+            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+            # compute context
+            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alpha_norm, alignment
+        else:
+            context = torch.bmm(alignment.unsqueeze(1), inputs)
+            context = context.squeeze(1)
+            return context, alignment, alignment


 class Postnet(nn.Module):
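For readers unfamiliar with forward attention, here is a small self-contained sketch (not from the commit) of the recurrence implemented above: `alpha` starts as a one-hot on the first encoder step, `u` is a transition weight (fixed at 0.5 here, as in `init_forward_attn_state`), and at every decoder step the previous `alpha` is mixed with its right-shifted copy, reweighted by the content-based alignment, and renormalized, which biases attention to move monotonically left to right:

```python
import torch
import torch.nn.functional as F

B, T = 2, 6                      # batch size, number of encoder timesteps

# init_forward_attn_state: all probability mass on the first encoder step
alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T - 1)], dim=1)
u = 0.5 * torch.ones(B, 1)       # transition weight between "stay" and "move"


def forward_attn_step(alpha, u, alignment):
    """One decoder step of the forward-attention update from the diff above."""
    prev_alpha = F.pad(alpha[:, :-1], (1, 0))            # shift right by one step
    alpha = ((1 - u) * alpha + u * prev_alpha + 1e-7) * alignment
    alpha_norm = alpha / alpha.sum(dim=1, keepdim=True)  # renormalize
    return alpha, alpha_norm


# three decoder steps with made-up content-based alignments
for _ in range(3):
    alignment = torch.softmax(torch.randn(B, T), dim=1)
    alpha, alpha_norm = forward_attn_step(alpha, u, alignment)
    print(alpha_norm)
```

In the decoder changes further down, the normalized `alpha_norm` is what forms the context vector, while the raw `alignment` is still what is accumulated into `attention_weights_cum` for the location features, which is why the layer now returns three tensors.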
@@ -242,7 +272,7 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -255,14 +285,14 @@ class Decoder(nn.Module):
         self.p_attention_dropout = 0.1
         self.p_decoder_dropout = 0.1

-        self.prenet = Prenet(self.mel_channels * r,
+        self.prenet = Prenet(self.mel_channels * r, prenet_type,
                              [self.prenet_dim, self.prenet_dim])

         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -340,11 +370,11 @@ class Decoder(nn.Module):
         attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
                                    self.attention_weights_cum.unsqueeze(1)),
                                   dim=1)
-        self.context, self.attention_weights = self.attention_layer(
+        self.context, self.attention_weights, alignments = self.attention_layer(
            self.attention_hidden, self.inputs, self.processed_inputs,
            attention_cat, self.mask)

-        self.attention_weights_cum += self.attention_weights
+        self.attention_weights_cum += alignments
         memory = torch.cat(
             (self.attention_hidden, self.context), -1)
         self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
@@ -372,6 +402,8 @@ class Decoder(nn.Module):
         memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)

         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
@@ -392,6 +424,9 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None)

         self.attention_layer.init_win_idx()
+        if self.attention_layer.forward_attn:
+            self.attention_layer.init_forward_attn_state(inputs)
+
         outputs, stop_tokens, alignments, t = [], [], [], 0
         stop_flags = [True, False, False]
         stop_count = 0
@@ -433,8 +468,9 @@ class Decoder(nn.Module):
         self._init_states(inputs, mask=None, keep_states=True)

         self.attention_layer.init_win_idx()
+        self.attention_layer.init_forward_attn_state()
         outputs, gate_outputs, alignments, t = [], [], [], 0
-        stop_flags = [False, False]
+        stop_flags = [False, False, False]
         stop_count = 0
         while True:
             memory = self.prenet(self.memory_truncated)
@@ -454,6 +490,7 @@ class Decoder(nn.Module):
             elif len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break

+            self.memory_truncated = mel_output
             t += 1

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
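For reference, a hedged construction example (not in the commit) using the updated signature; the module path follows the `importlib` convention train.py uses (`'models.' + c.model.lower()`), and `num_chars=61` is just a placeholder:

```python
from models.tacotron2 import Tacotron2

model = Tacotron2(num_chars=61, r=5,
                  attn_win=False, attn_norm="softmax",
                  prenet_type="bn", forward_attn=True)
print(sum(p.numel() for p in model.parameters()), "parameters")
```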
train.py
@@ -23,7 +23,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
                                  save_checkpoint, sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file)
+                                 set_init_dict, copy_config_file, setup_model)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
+    model = setup_model(num_chars, c)

     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

@@ -528,9 +528,6 @@ if __name__ == '__main__':
     # Conditional imports
     preprocessor = importlib.import_module('datasets.preprocess')
     preprocessor = getattr(preprocessor, c.dataset.lower())
-    print(" > Using model: {}".format(c.model))
-    MyModel = importlib.import_module('models.'+c.model.lower())
-    MyModel = getattr(MyModel, c.model)

     # Audio processor
     ap = AudioProcessor(**c.audio)
utils/generic_utils.py
@@ -8,6 +8,7 @@ import datetime
 import json
 import torch
 import subprocess
+import importlib
 import numpy as np
 from collections import OrderedDict
 from torch.autograd import Variable
@@ -236,4 +237,15 @@ def set_init_dict(model_dict, checkpoint, c):
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
     print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
     return model_dict
+
+
+def setup_model(num_chars, c):
+    print(" > Using model: {}".format(c.model))
+    MyModel = importlib.import_module('models.'+c.model.lower())
+    MyModel = getattr(MyModel, c.model)
+    if c.model.lower() == "tacotron":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, memory_size=c.memory_size)
+    elif c.model.lower() == "tacotron2":
+        model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm, prenet_type=c.prenet_type, forward_attn=c.use_forward_attn)
+    return model
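A short usage sketch (not in the commit) of the new factory from a training script, mirroring the train.py change above; `load_config`, `phonemes`, and `symbols` are the helpers train.py already imports:

```python
from utils.generic_utils import load_config, setup_model
from utils.text.symbols import phonemes, symbols

c = load_config("config.json")            # e.g. c.model == "Tacotron2"
num_chars = len(phonemes) if c.use_phonemes else len(symbols)

model = setup_model(num_chars, c)         # dispatches on c.model
print(type(model).__name__)               # "Tacotron" or "Tacotron2"
```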