From d41bcd27315ce427c74b289c73e7334ce6695bef Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Wed, 25 Apr 2018 05:36:00 -0700
Subject: [PATCH] add mk annealing (mk attn loss contribution)

---
 config.json            | 55 +++++++++++++++++++++---------------------
 train.py               | 21 ++++++-----------
 utils/generic_utils.py | 18 ++++++++++++++
 3 files changed, 54 insertions(+), 40 deletions(-)

diff --git a/config.json b/config.json
index 244150b9..27d630a5 100644
--- a/config.json
+++ b/config.json
@@ -1,33 +1,34 @@
 {
-    "num_mels": 80,
-    "num_freq": 1025,
-    "sample_rate": 22050,
-    "frame_length_ms": 50,
-    "frame_shift_ms": 12.5,
-    "preemphasis": 0.97,
-    "min_level_db": -100,
-    "ref_level_db": 20,
-    "embedding_size": 256,
-    "text_cleaner": "english_cleaners",
+    "num_mels": 80,
+    "num_freq": 1025,
+    "sample_rate": 22050,
+    "frame_length_ms": 50,
+    "frame_shift_ms": 12.5,
+    "preemphasis": 0.97,
+    "min_level_db": -100,
+    "ref_level_db": 20,
+    "embedding_size": 256,
+    "text_cleaner": "english_cleaners",
 
-    "epochs": 500,
-    "lr": 0.002,
-    "warmup_steps": 4000,
-    "batch_size": 32,
-    "eval_batch_size":32,
-    "r": 5,
+    "epochs": 500,
+    "lr": 0.002,
+    "warmup_steps": 4000,
+    "batch_size": 32,
+    "eval_batch_size":32,
+    "r": 5,
+    "mk": 1,
 
-    "griffin_lim_iters": 60,
-    "power": 1.2,
+    "griffin_lim_iters": 60,
+    "power": 1.2,
 
-    "dataset": "LJSpeech",
-    "meta_file_train": "metadata_train.csv",
-    "meta_file_val": "metadata_val.csv",
-    "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
-    "min_seq_len": 0,
-    "num_loader_workers": 8,
+    "dataset": "LJSpeech",
+    "meta_file_train": "metadata_train.csv",
+    "meta_file_val": "metadata_val.csv",
+    "data_path": "/data/shared/KeithIto/LJSpeech-1.0/",
+    "min_seq_len": 0,
+    "num_loader_workers": 8,
 
-    "checkpoint": true,
-    "save_step": 908,
-    "output_path": "/data/shared/erogol_models/"
+    "checkpoint": true,
+    "save_step": 908,
+    "output_path": "/data/shared/erogol_models/"
 }
diff --git a/train.py b/train.py
index 1e3a0a59..4108f115 100644
--- a/train.py
+++ b/train.py
@@ -19,7 +19,8 @@ from tensorboardX import SummaryWriter
 from utils.generic_utils import (Progbar, remove_experiment_folder,
                                  create_experiment_folder, save_checkpoint,
                                  save_best_model, load_config, lr_decay,
-                                 count_parameters, check_update, get_commit_hash)
+                                 count_parameters, check_update, get_commit_hash,
+                                 create_attn_mask, mk_decay)
 from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from models.tacotron import Tacotron
@@ -90,6 +91,9 @@ def train(model, criterion, data_loader, optimizer, epoch):
             params_group['lr'] = current_lr
 
         optimizer.zero_grad()
+
+        # set up mk (attention loss weight) for this epoch
+        mk = mk_decay(c.mk, c.epochs, epoch)
 
         # convert inputs to variables
         text_input_var = Variable(text_input)
@@ -105,19 +109,10 @@ def train(model, criterion, data_loader, optimizer, epoch):
             linear_spec_var = linear_spec_var.cuda()
 
         # create attention mask
-        # TODO: vectorize
         N = text_input_var.shape[1]
         T = mel_spec_var.shape[1] // c.r
-        M = np.zeros([N, T])
-        for t in range(T):
-            for n in range(N):
-                val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/0.05)
-                M[n, t] = val
-        e_x = np.exp(M - np.max(M))
-        M = e_x / e_x.sum(axis=0)  # only difference
-        M = Variable(torch.FloatTensor(M).t()).cuda()
-        M = torch.stack([M]*32)
-
+        M = create_attn_mask(N, T, 0.05)
+
         # forward pass
         mel_output, linear_output, alignments =\
             model.forward(text_input_var, mel_spec_var)
@@ -129,7 +124,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
                               linear_spec_var[:, :, :n_priority_freq],
                               mel_lengths_var)
         attention_loss = criterion(alignments, M, mel_lengths_var)
-        loss = mel_loss + linear_loss + attention_loss
+        loss = mel_loss + linear_loss + mk * attention_loss
 
         # backpass and check the grad norm
         loss.backward()
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index d42fcde4..33508f9e 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -131,6 +131,24 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr
 
 
+def create_attn_mask(N, T, g=0.05):
+    r'''Create a guided attention mask of shape [batch, T, N].'''
+    M = np.zeros([N, T])
+    for t in range(T):
+        for n in range(N):
+            val = 20 * np.exp(-pow((n/N)-(t/T), 2.0)/g)
+            M[n, t] = val
+    e_x = np.exp(M - np.max(M))
+    M = e_x / e_x.sum(axis=0)  # normalize over the text axis (softmax)
+    M = Variable(torch.FloatTensor(M).t()).cuda()
+    M = torch.stack([M]*32)  # NOTE: batch size hard-coded to 32
+    return M
+
+
+def mk_decay(init_mk, max_epoch, n_epoch):
+    return init_mk * ((max_epoch - n_epoch) / max_epoch)  # linear decay of mk to 0
+
+
 def count_parameters(model):
     r"""Count number of trainable parameters in a network"""
     return sum(p.numel() for p in model.parameters() if p.requires_grad)
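
Note on the two helpers added above. create_attn_mask() still builds the guide matrix with the nested loop that the removed "# TODO: vectorize" comment pointed at, using the guide width g=0.05 and a batch size hard-coded to 32. Below is a vectorized sketch of the same computation; the name create_attn_mask_vectorized and the batch_size argument are illustrative only (not part of this patch), and the Variable(...).cuda() step from the patch is left to the caller:

    import numpy as np
    import torch

    def create_attn_mask_vectorized(N, T, g=0.05, batch_size=32):
        # same guide as the loop above: 20 * exp(-((n/N - t/T)^2) / g)
        n = np.arange(N, dtype=np.float32)[:, None] / N    # [N, 1]
        t = np.arange(T, dtype=np.float32)[None, :] / T    # [1, T]
        M = 20 * np.exp(-((n - t) ** 2) / g)                # [N, T]
        # normalize over the text axis, as in the original code
        e_x = np.exp(M - np.max(M))
        M = e_x / e_x.sum(axis=0)
        M = torch.FloatTensor(M).t()                        # [T, N]
        return M.unsqueeze(0).repeat(batch_size, 1, 1)      # [batch_size, T, N]

As for the annealing itself: mk_decay() scales the attention loss weight linearly toward zero, so with the config values in this patch ("mk": 1, "epochs": 500) the guided-attention term enters the total loss at full weight in epoch 0, at roughly half weight around epoch 250, and is close to zero by the final epoch.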