mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'dev-gradual-queue' into dev
commit b22c7d4a29

.compute (6 changed lines)
@@ -4,13 +4,13 @@ yes | apt-get install ffmpeg
 yes | apt-get install espeak
 yes | apt-get install tmux
 yes | apt-get install zsh
-pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
+# pip3 install https://download.pytorch.org/whl/cu100/torch-1.1.0-cp37-cp37m-linux_x86_64.whl
 # wget https://www.dropbox.com/s/m8waow6b3ydpf6h/MozillaDataset.tar.gz?dl=0 -O /data/rw/home/mozilla.tar
 wget https://www.dropbox.com/s/wqn5v3wkktw9lmo/install.sh?dl=0 -O install.sh
 sudo sh install.sh
 python3 setup.py develop
 # cp -R ${USER_DIR}/GermanData ../tmp/
-python3 distribute.py --config_path config_libritts.json --data_path /data/rw/home/LibriTTS/train-clean-360/
+python3 distribute.py --config_path config.json --data_path /data/ro/shared/data/keithito/LJSpeech-1.1/
 # cp -R ${USER_DIR}/Mozilla_22050 ../tmp/
 # python3 distribute.py --config_path config_tacotron_gst.json --data_path ../tmp/Mozilla_22050/
-while true; do sleep 1000000; done
+# while true; do sleep 1000000; done

config.json (49 changed lines)
@@ -1,6 +1,6 @@
 {
-"run_name": "mozilla-no-loc-fattn-stopnet-sigmoid-loss_masking",
-"run_description": "using forward attention, with original prenet, loss masking,separate stopnet, sigmoid. Compare this with 4817. Pytorch DPP",
+"run_name": "ljspeech",
+"run_description": "gradual training with prenet frame size 1 + no maxout for cbhg + symmetric norm.",

 "audio":{
 // Audio processing parameters
@@ -16,8 +16,8 @@
 "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.
 // Normalization parameters
 "signal_norm": true, // normalize the spec values in range [0, 1]
-"symmetric_norm": false, // move normalization to range [-1, 1]
-"max_norm": 1, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
+"symmetric_norm": true, // move normalization to range [-1, 1]
+"max_norm": 4, // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
 "clip_norm": true, // clip normalized values into the range.
 "mel_fmin": 0.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
 "mel_fmax": 8000.0, // maximum freq level for mel-spec. Tune for dataset!!
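The two flags changed above widen the normalization target from [0, 1] to a symmetric [-4, 4] range. As a rough, assumed sketch of what these options usually mean for this repo's audio processing (the audio processor itself is not part of this diff; min_level_db and the exact formula are assumptions):

import numpy as np

def normalize_spectrogram(S_db, min_level_db=-100.0, symmetric_norm=True,
                          max_norm=4.0, clip_norm=True):
    # scale dB values to [0, 1] first ("signal_norm")
    s = (S_db - min_level_db) / -min_level_db
    if symmetric_norm:
        # "symmetric_norm": move to [-max_norm, max_norm]
        s = 2 * max_norm * s - max_norm
        return np.clip(s, -max_norm, max_norm) if clip_norm else s
    # otherwise keep [0, max_norm]
    s = max_norm * s
    return np.clip(s, 0, max_norm) if clip_norm else s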
@@ -31,52 +31,53 @@

 "reinit_layers": [],

-"model": "Tacotron2", // one of the model in models/
+"model": "Tacotron", // one of the model in models/
 "grad_clip": 1, // upper limit for gradients for clipping.
 "epochs": 1000, // total number of epochs to train.
 "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
 "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
 "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
-"windowing": false, // Enables attention windowing. Used only in eval mode.
-"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+"memory_size": -1, // ONLY TACOTRON - size of the memory queue used fro storing last decoder predictions for auto-regression. If < 0, memory queue is disabled and decoder only uses the last prediction frame.
 "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
-"prenet_type": "original", // "original" or "bn".
-"prenet_dropout": true, // enable/disable dropout at prenet.
-"use_forward_attn": true, // enable/disable forward attention. In general, it aligns faster.
-"forward_attn_mask": false, // Apply forward attention mask at inference to prevent bad modes. Try it if your model does not align well.
-"transition_agent": false, // enable/disable transition agent of forward attention.
-"location_attn": false, // enable_disable location sensitive attention.
+"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
+"prenet_dropout": true, // ONLY TACOTRON2 - enable/disable dropout at prenet.
+"windowing": false, // Enables attention windowing. Used only in eval mode.
+"use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+"forward_attn_mask": false,
+"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+"location_attn": true, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
 "loss_masking": true, // enable / disable loss masking against the sequence padding.
 "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
 "stopnet": true, // Train stopnet predicting the end of synthesis.
 "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
 "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

-"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
 "eval_batch_size":16,
-"r": 1, // Number of frames to predict for step.
+"r": 7, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
+"gradual_training": [[0, 7, 32], [10000, 5, 32], [50000, 3, 32], [130000, 2, 16], [290000, 1, 8]], // set gradual training steps [first_step, r, batch_size]. If it is null, gradual training is disabled.
 "wd": 0.000001, // Weight decay weight.
 "checkpoint": true, // If true, it saves checkpoints per "save_step"
-"save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.
-"print_step": 10, // Number of steps to log traning on console.
+"save_step": 10000, // Number of training steps expected to save traning stats and checkpoints.
+"print_step": 25, // Number of steps to log traning on console.
 "batch_group_size": 0, //Number of batches to shuffle after bucketing.

 "run_eval": true,
 "test_delay_epochs": 5, //Until attention is aligned, testing only wastes computation time.
 "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences.
-"data_path": "/media/erogol/data_ssd/Data/Mozilla/", // DATASET-RELATED: can overwritten from command argument
-"meta_file_train": "metadata_train.txt", // DATASET-RELATED: metafile for training dataloader.
-"meta_file_val": "metadata_val.txt", // DATASET-RELATED: metafile for evaluation dataloader.
-"dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
-"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
+"data_path": "/home/erogol/Data/LJSpeech-1.1/", // DATASET-RELATED: can overwritten from command argument
+"meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
+"meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
+"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
+"min_seq_len": 6, // DATASET-RELATED: minimum text length to use in training
 "max_seq_len": 150, // DATASET-RELATED: maximum text length
-"output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
+"output_path": "/media/erogol/data_ssd/Models/libri_tts/", // DATASET-RELATED: output path for all training outputs.
 "num_loader_workers": 4, // number of training data loader processes. Don't set it too big. 4-8 are good values.
 "num_val_loader_workers": 4, // number of evaluation data loader processes.
 "phoneme_cache_path": "mozilla_us_phonemes", // phoneme computation is slow, therefore, it caches results in the given folder.
 "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronounciation.
 "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages
 "text_cleaner": "phoneme_cleaners",
-"use_speaker_embedding": false // whether to use additional embeddings for separate speakers
+"use_speaker_embedding": false
 }

@@ -1,6 +1,6 @@
 {
 "run_name": "libritts-360",
-"run_description": "LibriTTS 360 clean with multi speaker embedding.",
+"run_description": "LibriTTS 360 gradual traning with memory queue.",

 "audio":{
 // Audio processing parameters

@@ -31,13 +31,13 @@

 "reinit_layers": [],

-"model": "Tacotron2", // one of the model in models/
+"model": "Tacotron", // one of the model in models/
 "grad_clip": 1, // upper limit for gradients for clipping.
 "epochs": 1000, // total number of epochs to train.
 "lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
 "lr_decay": false, // if true, Noam learning rate decaying is applied through training.
 "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
-"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+"memory_size": 7, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
 "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
 "prenet_type": "original", // "original" or "bn".
 "prenet_dropout": true, // enable/disable dropout at prenet.

@@ -52,9 +52,9 @@
 "separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
 "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

-"batch_size": 24, // Batch size for training. Lower values than 32 might cause hard to learn attention.
+"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
 "eval_batch_size":16,
-"r": 1, // Number of frames to predict for step.
+"r": 7, // Number of frames to predict for step.
 "wd": 0.000001, // Weight decay weight.
 "checkpoint": true, // If true, it saves checkpoints per "save_step"
 "save_step": 1000, // Number of training steps expected to save traning stats and checkpoints.

@@ -234,7 +234,7 @@ class Attention(nn.Module):
 query, processed_inputs)
 # apply masking
 if mask is not None:
-attention.data.masked_fill_(1 - mask, self._mask_value)
+attention.data.masked_fill_(torch.bitwise_not(mask), self._mask_value)
 # apply windowing - only in eval mode
 if not self.training and self.windowing:
 attention = self.apply_windowing(attention, inputs)
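The single changed line above replaces `1 - mask` with `torch.bitwise_not(mask)`, presumably because recent PyTorch versions return bool tensors from comparisons and no longer allow `1 - mask` on them. A minimal, self-contained illustration of the same masking pattern (values and shapes are made up, -1e4 stands in for self._mask_value):

import torch

scores = torch.randn(2, 5)                                # attention energies (B, T)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 0, 0, 0]], dtype=torch.bool)  # True = valid position

# fill padded positions with a very negative value before normalizing
scores.masked_fill_(torch.bitwise_not(mask), -1e4)        # same as scores.masked_fill_(~mask, -1e4)
weights = torch.softmax(scores, dim=-1)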
@@ -135,9 +135,6 @@ class CBHG(nn.Module):
 ])
 # max pooling of conv bank, with padding
 # TODO: try average pooling OR larger kernel size
-self.max_pool1d = nn.Sequential(
-nn.ConstantPad1d([0, 1], value=0),
-nn.MaxPool1d(kernel_size=2, stride=1, padding=0))
 out_features = [K * conv_bank_features] + conv_projections[:-1]
 activations = [self.relu] * (len(conv_projections) - 1)
 activations += [None]

@@ -186,7 +183,6 @@ class CBHG(nn.Module):
 outs.append(out)
 x = torch.cat(outs, dim=1)
 assert x.size(1) == self.conv_bank_features * len(self.conv1d_banks)
-x = self.max_pool1d(x)
 for conv1d in self.conv1d_projections:
 x = conv1d(x)
 # (B, T_in, hid_feature)

@@ -270,23 +266,27 @@ class Decoder(nn.Module):
 memory_size (int): size of the past window. if <= 0 memory_size = r
 TODO: arguments
 """

 # Pylint gets confused by PyTorch conventions here
 #pylint: disable=attribute-defined-outside-init

 def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
 attn_norm, prenet_type, prenet_dropout, forward_attn,
-trans_agent, forward_attn_mask, location_attn, separate_stopnet):
+trans_agent, forward_attn_mask, location_attn,
+separate_stopnet):
 super(Decoder, self).__init__()
+self.r_init = r
 self.r = r
 self.in_features = in_features
 self.max_decoder_steps = 500
+self.use_memory_queue = memory_size > 0
 self.memory_size = memory_size if memory_size > 0 else r
 self.memory_dim = memory_dim
 self.separate_stopnet = separate_stopnet
+self.query_dim = 256
 # memory -> |Prenet| -> processed_memory
 self.prenet = Prenet(
-memory_dim * self.memory_size,
+memory_dim * self.memory_size if self.use_memory_queue else memory_dim,
 prenet_type,
 prenet_dropout,
 out_features=[256, 128])

@@ -311,21 +311,12 @@ class Decoder(nn.Module):
 self.decoder_rnns = nn.ModuleList(
 [nn.GRUCell(256, 256) for _ in range(2)])
 # RNN_state -> |Linear| -> mel_spec
-self.proj_to_mel = nn.Linear(256, memory_dim * r)
+self.proj_to_mel = nn.Linear(256, memory_dim * self.r_init)
 # learn init values instead of zero init.
-self.attention_rnn_init = nn.Embedding(1, 256)
-self.memory_init = nn.Embedding(1, self.memory_size * memory_dim)
-self.decoder_rnn_inits = nn.Embedding(2, 256)
-self.stopnet = StopNet(256 + memory_dim * r)
-# self.init_layers()
+self.stopnet = StopNet(256 + memory_dim * self.r_init)

-def init_layers(self):
-torch.nn.init.xavier_uniform_(
-self.project_to_decoder_in.weight,
-gain=torch.nn.init.calculate_gain('linear'))
-torch.nn.init.xavier_uniform_(
-self.proj_to_mel.weight,
-gain=torch.nn.init.calculate_gain('linear'))
+def _set_r(self, new_r):
+self.r = new_r

 def _reshape_memory(self, memory):
 """
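A side effect of the new r_init attribute above: proj_to_mel and stopnet are sized for the largest r the decoder will ever use, while self.r can later be lowered through _set_r() as the gradual-training schedule tightens. A simplified, hypothetical sketch of that sizing/slicing idea (not the project's class, just the trick):

import torch
import torch.nn as nn

class TinyDecoderHead(nn.Module):
    """Illustration only: project to the maximum frame count, slice to the current r."""
    def __init__(self, hidden_dim=256, memory_dim=80, r_init=7):
        super().__init__()
        self.r_init = r_init
        self.r = r_init
        self.memory_dim = memory_dim
        self.proj_to_mel = nn.Linear(hidden_dim, memory_dim * r_init)  # sized once, for the max r

    def _set_r(self, new_r):
        self.r = new_r  # gradual training shrinks r without touching the weights

    def forward(self, decoder_output):
        output = self.proj_to_mel(decoder_output)
        return output[:, : self.r * self.memory_dim]  # keep only the frames for the current r

head = TinyDecoderHead()
head._set_r(2)
print(head(torch.randn(4, 256)).shape)  # torch.Size([4, 160]) -> 2 frames of 80 mel bins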
@@ -347,13 +338,14 @@ class Decoder(nn.Module):
 B = inputs.size(0)
 T = inputs.size(1)
 # go frame as zeros matrix
-self.memory_input = self.memory_init(inputs.data.new_zeros(B).long())
-
+if self.use_memory_queue:
+self.memory_input = torch.zeros(B, self.memory_dim * self.memory_size, device=inputs.device)
+else:
+self.memory_input = torch.zeros(B, self.memory_dim, device=inputs.device)
 # decoder states
-self.query = self.attention_rnn_init(
-inputs.data.new_zeros(B).long())
+self.attention_rnn_hidden = torch.zeros(B, 256, device=inputs.device)
 self.decoder_rnn_hiddens = [
-self.decoder_rnn_inits(inputs.data.new_tensor([idx] * B).long())
+torch.zeros(B, 256, device=inputs.device)
 for idx in range(len(self.decoder_rnns))
 ]
 self.context_vec = inputs.data.new(B, self.in_features).zero_()

@@ -370,12 +362,13 @@ class Decoder(nn.Module):
 def decode(self, inputs, mask=None):
 # Prenet
 processed_memory = self.prenet(self.memory_input)
-
-# Attention
-self.query = self.attention_rnn(torch.cat((processed_memory, self.context_vec), -1), self.query)
-self.context_vec = self.attention(self.query, inputs, self.processed_inputs, mask)
-
-# Concat query and attention context vector
+# Attention RNN
+self.attention_rnn_hidden = self.attention_rnn(
+torch.cat((processed_memory, self.current_context_vec), -1),
+self.attention_rnn_hidden)
+self.context_vec = self.attention_layer(
+self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
+# Concat RNN output and attention context vector
 decoder_input = self.project_to_decoder_in(
 torch.cat((self.query, self.context_vec), -1))

@@ -389,25 +382,30 @@ class Decoder(nn.Module):

 # predict mel vectors from decoder vectors
 output = self.proj_to_mel(decoder_output)
-output = torch.sigmoid(output)
-
+# output = torch.sigmoid(output)
 # predict stop token
 stopnet_input = torch.cat([decoder_output, output], -1)
 if self.separate_stopnet:
 stop_token = self.stopnet(stopnet_input.detach())
 else:
 stop_token = self.stopnet(stopnet_input)
-return output, stop_token, self.attention.attention_weights
+output = output[:, : self.r * self.memory_dim]
+return output, stop_token, self.attention_layer.attention_weights

-def _update_memory_queue(self, new_memory):
-if self.memory_size > 0 and new_memory.shape[-1] < self.memory_size:
-self.memory_input = torch.cat([
-self.memory_input[:, self.r * self.memory_dim:].clone(),
-new_memory
-],
-dim=-1)
+def _update_memory_input(self, new_memory):
+if self.use_memory_queue:
+if self.memory_size > self.r:
+# memory queue size is larger than number of frames per decoder iter
+self.memory_input = torch.cat([
+new_memory, self.memory_input[:, :(self.memory_size - self.r) * self.memory_dim].clone()
+],
+dim=-1)
+else:
+# memory queue size smaller than number of frames per decoder iter
+self.memory_input = new_memory[:, :self.memory_size * self.memory_dim]
 else:
-self.memory_input = new_memory
+# use only the last frame prediction
+self.memory_input = new_memory[:, :self.memory_dim]

 def forward(self, inputs, memory, mask):
 """

@@ -433,7 +431,7 @@ class Decoder(nn.Module):
 while len(outputs) < memory.size(0):
 if t > 0:
 new_memory = memory[t - 1]
-self._update_memory_queue(new_memory)
+self._update_memory_input(new_memory)
 output, stop_token, attention = self.decode(inputs, mask)
 outputs += [output]
 attentions += [attention]

@@ -460,7 +458,7 @@ class Decoder(nn.Module):
 while True:
 if t > 0:
 new_memory = outputs[-1]
-self._update_memory_queue(new_memory)
+self._update_memory_input(new_memory)
 output, stop_token, attention = self.decode(inputs, None)
 stop_token = torch.sigmoid(stop_token.data)
 outputs += [output]

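The renamed _update_memory_input above either keeps a rolling queue of the last memory_size predicted frames (when use_memory_queue is set) or just the single last frame for the autoregressive connection. A standalone sketch of that update, written to mirror the logic in the diff (shapes and the newest-first frame layout are assumptions):

import torch

def update_memory_input(memory_input, new_memory, r, memory_size, memory_dim, use_memory_queue):
    """Sketch of the decoder's memory update; tensors have shape (B, memory_dim * k)."""
    if use_memory_queue:
        if memory_size > r:
            # queue holds more frames than one decoder step produces:
            # push the r new frames in front, drop the r oldest ones
            kept = memory_input[:, :(memory_size - r) * memory_dim].clone()
            return torch.cat([new_memory, kept], dim=-1)
        # queue is not larger than one step: keep the newest memory_size frames
        return new_memory[:, :memory_size * memory_dim]
    # no queue: autoregress on the last predicted frame only
    return new_memory[:, :memory_dim]

B, memory_dim, r, memory_size = 2, 80, 7, 7
memory_input = torch.zeros(B, memory_dim * memory_size)
new_memory = torch.randn(B, memory_dim * r)
memory_input = update_memory_input(memory_input, new_memory, r, memory_size, memory_dim, True)
print(memory_input.shape)  # torch.Size([2, 560])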
@@ -36,10 +36,8 @@ class Tacotron(nn.Module):
 forward_attn, trans_agent, forward_attn_mask,
 location_attn, separate_stopnet)
 self.postnet = PostCBHG(mel_dim)
-self.last_linear = nn.Sequential(
-nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
-nn.Sigmoid())
+self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim)

 def forward(self, characters, text_lengths, mel_specs, speaker_ids=None):
 B = characters.size(0)
 mask = sequence_mask(text_lengths).to(characters.device)

train.py (77 changed lines)
@@ -20,7 +20,7 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
 load_config, remove_experiment_folder,
 save_best_model, save_checkpoint, weight_decay,
 set_init_dict, copy_config_file, setup_model,
-split_dataset)
+split_dataset, gradual_training_scheduler)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
 get_speakers

@@ -82,7 +82,7 @@ def setup_loader(ap, is_val=False, verbose=False):


 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
-ap, epoch):
+ap, global_step, epoch):
 data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
 if c.use_speaker_embedding:
 speaker_mapping = load_speaker_mapping(OUT_PATH)

@@ -92,8 +92,10 @@
 avg_decoder_loss = 0
 avg_stop_loss = 0
 avg_step_time = 0
+avg_loader_time = 0
 print("\n > Epoch {}/{}".format(epoch, c.epochs), flush=True)
 batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
+end_time = time.time()
 for num_iter, data in enumerate(data_loader):
 start_time = time.time()

@@ -107,6 +109,7 @@
 stop_targets = data[6]
 avg_text_length = torch.mean(text_lengths.float())
 avg_spec_length = torch.mean(mel_lengths.float())
+loader_time = time.time() - end_time

 if c.use_speaker_embedding:
 speaker_ids = [speaker_mapping[speaker_name]

@@ -120,8 +123,7 @@
 stop_targets.size(1) // c.r, -1)
 stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

-current_step = num_iter + args.restore_step + \
-epoch * len(data_loader) + 1
+global_step += 1

 # setup lr
 if c.lr_decay:
|
|||
optimizer_st.step()
|
||||
else:
|
||||
grad_norm_st = 0
|
||||
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
if current_step % c.print_step == 0:
|
||||
if global_step % c.print_step == 0:
|
||||
print(
|
||||
" | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} PostnetLoss:{:.5f} "
|
||||
"DecoderLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
|
||||
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} LR:{:.6f}".format(
|
||||
num_iter, batch_n_iter, current_step, loss.item(),
|
||||
"GradNormST:{:.5f} AvgTextLen:{:.1f} AvgSpecLen:{:.1f} StepTime:{:.2f} "
|
||||
"LoaderTime:{:.2f} LR:{:.6f}".format(
|
||||
num_iter, batch_n_iter, global_step, loss.item(),
|
||||
postnet_loss.item(), decoder_loss.item(), stop_loss.item(),
|
||||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
||||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time,
|
||||
loader_time, current_lr),
|
||||
flush=True)
|
||||
|
||||
# aggregate losses from processes
|
||||
|
@ -202,6 +206,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
avg_decoder_loss += float(decoder_loss.item())
|
||||
avg_stop_loss += stop_loss if isinstance(stop_loss, float) else float(stop_loss.item())
|
||||
avg_step_time += step_time
|
||||
avg_loader_time += loader_time
|
||||
|
||||
# Plot Training Iter Stats
|
||||
iter_stats = {"loss_posnet": postnet_loss.item(),
|
||||
|
@ -210,13 +215,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
"grad_norm": grad_norm,
|
||||
"grad_norm_st": grad_norm_st,
|
||||
"step_time": step_time}
|
||||
tb_logger.tb_train_iter_stats(current_step, iter_stats)
|
||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||
|
||||
if current_step % c.save_step == 0:
|
||||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, optimizer_st,
|
||||
postnet_loss.item(), OUT_PATH, current_step,
|
||||
postnet_loss.item(), OUT_PATH, global_step,
|
||||
epoch)
|
||||
|
||||
# Diagnostic visualizations
|
||||
|
@ -229,31 +234,34 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
"ground_truth": plot_spectrogram(gt_spec, ap),
|
||||
"alignment": plot_alignment(align_img)
|
||||
}
|
||||
tb_logger.tb_train_figures(current_step, figures)
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
train_audio = ap.inv_spectrogram(const_spec.T)
|
||||
else:
|
||||
train_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
tb_logger.tb_train_audios(current_step,
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'TrainAudio': train_audio},
|
||||
c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
avg_postnet_loss /= (num_iter + 1)
|
||||
avg_decoder_loss /= (num_iter + 1)
|
||||
avg_stop_loss /= (num_iter + 1)
|
||||
avg_total_loss = avg_decoder_loss + avg_postnet_loss + avg_stop_loss
|
||||
avg_step_time /= (num_iter + 1)
|
||||
avg_loader_time /= (num_iter + 1)
|
||||
|
||||
# print epoch stats
|
||||
print(
|
||||
" | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
|
||||
"AvgPostnetLoss:{:.5f} AvgDecoderLoss:{:.5f} "
|
||||
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
|
||||
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
|
||||
"AvgStepTime:{:.2f} AvgLoaderTime:{:.2f}".format(global_step, avg_total_loss,
|
||||
avg_postnet_loss, avg_decoder_loss,
|
||||
avg_stop_loss, epoch_time, avg_step_time),
|
||||
avg_stop_loss, epoch_time, avg_step_time,
|
||||
avg_loader_time),
|
||||
flush=True)
|
||||
|
||||
# Plot Epoch Stats
|
||||
|
@ -263,14 +271,13 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
"loss_decoder": avg_decoder_loss,
|
||||
"stop_loss": avg_stop_loss,
|
||||
"epoch_time": epoch_time}
|
||||
tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
|
||||
tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
|
||||
if c.tb_model_param_stats:
|
||||
tb_logger.tb_model_weights(model, current_step)
|
||||
|
||||
return avg_postnet_loss, current_step
|
||||
tb_logger.tb_model_weights(model, global_step)
|
||||
return avg_postnet_loss, global_step
|
||||
|
||||
|
||||
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||
def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
||||
data_loader = setup_loader(ap, is_val=True)
|
||||
if c.use_speaker_embedding:
|
||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||
|
@@ -383,14 +390,14 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
 "ground_truth": plot_spectrogram(gt_spec, ap),
 "alignment": plot_alignment(align_img)
 }
-tb_logger.tb_eval_figures(current_step, eval_figures)
+tb_logger.tb_eval_figures(global_step, eval_figures)

 # Sample audio
 if c.model in ["Tacotron", "TacotronGST"]:
 eval_audio = ap.inv_spectrogram(const_spec.T)
 else:
 eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-tb_logger.tb_eval_audios(current_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
+tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

 # compute average losses
 avg_postnet_loss /= (num_iter + 1)

@@ -401,7 +408,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
 epoch_stats = {"loss_postnet": avg_postnet_loss,
 "loss_decoder": avg_decoder_loss,
 "stop_loss": avg_stop_loss}
-tb_logger.tb_eval_stats(current_step, epoch_stats)
+tb_logger.tb_eval_stats(global_step, epoch_stats)

 if args.rank == 0 and epoch > c.test_delay_epochs:
 # test sentences

@@ -427,8 +434,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
 except:
 print(" !! Error creating Test Sentence -", idx)
 traceback.print_exc()
-tb_logger.tb_test_audios(current_step, test_audios, c.audio['sample_rate'])
-tb_logger.tb_test_figures(current_step, test_figures)
+tb_logger.tb_test_audios(global_step, test_audios, c.audio['sample_rate'])
+tb_logger.tb_test_figures(global_step, test_figures)
 return avg_postnet_loss



@@ -526,11 +533,19 @@ def main(args): #pylint: disable=redefined-outer-name
 if 'best_loss' not in locals():
 best_loss = float('inf')

+global_step = args.restore_step
 for epoch in range(0, c.epochs):
-train_loss, current_step = train(model, criterion, criterion_st,
+# set gradual training
+if c.gradual_training is not None:
+r, c.batch_size = gradual_training_scheduler(global_step, c)
+c.r = r
+model.decoder._set_r(r)
+print(" > Number of outputs per iteration:", model.decoder.r)
+
+train_loss, global_step = train(model, criterion, criterion_st,
 optimizer, optimizer_st, scheduler,
-ap, epoch)
-val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
+ap, global_step, epoch)
+val_loss = evaluate(model, criterion, criterion_st, ap, global_step, epoch)
 print(
 " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(
 train_loss, val_loss),

@@ -539,7 +554,7 @@ def main(args): #pylint: disable=redefined-outer-name
 if c.run_eval:
 target_loss = val_loss
 best_loss = save_best_model(model, optimizer, target_loss, best_loss,
-OUT_PATH, current_step, epoch)
+OUT_PATH, global_step, epoch)


 if __name__ == '__main__':

@@ -573,7 +588,7 @@ if __name__ == '__main__':
 '--output_folder',
 type=str,
 default='',
-help='folder name for traning outputs.'
+help='folder name for training outputs.'
 )

 # DISTRUBUTED

@@ -121,7 +121,8 @@ def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
 'step': current_step,
 'epoch': epoch,
 'linear_loss': model_loss,
-'date': datetime.date.today().strftime("%B %d, %Y")
+'date': datetime.date.today().strftime("%B %d, %Y"),
+'r': model.decoder.r
 }
 torch.save(state, checkpoint_path)


@@ -136,7 +137,8 @@ def save_best_model(model, optimizer, model_loss, best_loss, out_path,
 'step': current_step,
 'epoch': epoch,
 'linear_loss': model_loss,
-'date': datetime.date.today().strftime("%B %d, %Y")
+'date': datetime.date.today().strftime("%B %d, %Y"),
+'r': model.decoder.r
 }
 best_loss = model_loss
 bestmodel_path = 'best_model.pth.tar'
@@ -305,3 +307,10 @@ def split_dataset(items):
 else:
 return items[:eval_split_size], items[eval_split_size:]

+
+def gradual_training_scheduler(global_step, config):
+new_values = None
+for values in config.gradual_training:
+if global_step >= values[0]:
+new_values = values
+return new_values[1], new_values[2]
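Given the 'gradual_training' schedule in config.json above, the new scheduler picks the last entry whose first_step has already been passed and returns its (r, batch_size). A small usage sketch with the config stubbed out (SimpleNamespace stands in for the real config object):

from types import SimpleNamespace

def gradual_training_scheduler(global_step, config):
    # same lookup as the function added above
    new_values = None
    for values in config.gradual_training:
        if global_step >= values[0]:
            new_values = values
    return new_values[1], new_values[2]

config = SimpleNamespace(gradual_training=[[0, 7, 32], [10000, 5, 32], [50000, 3, 32],
                                           [130000, 2, 16], [290000, 1, 8]])
print(gradual_training_scheduler(0, config))       # (7, 32)
print(gradual_training_scheduler(60000, config))   # (3, 32)
print(gradual_training_scheduler(300000, config))  # (1, 8)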