import os
import re
import sys
import glob
import time
import shutil
import datetime
import json
import torch
import subprocess
import importlib
import numpy as np
from collections import OrderedDict
from torch.autograd import Variable
from utils.text import text_to_sequence


class AttrDict(dict):
    """A dict whose keys are also accessible as attributes."""

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def load_config(config_path):
    """Load a JSON config file into an AttrDict, stripping `//` comments
    and backslash line continuations before parsing."""
    config = AttrDict()
    with open(config_path, "r") as f:
        input_str = f.read()
    input_str = re.sub(r'\\\n', '', input_str)
    input_str = re.sub(r'//.*\n', '\n', input_str)
    data = json.loads(input_str)
    config.update(data)
    return config
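
# Usage sketch (hypothetical config file; assumes it defines "model" and "r"):
#     c = load_config("config.json")
#     print(c.model, c.r)  # AttrDict allows attribute access: c.model == c["model"]
# `//` comments and trailing backslash continuations in the JSON are removed
# before parsing, so annotated config files load cleanly.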


def get_git_branch():
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
        current = next(line for line in out.split("\n") if line.startswith("*"))
        # str.replace returns a new string, so the result must be assigned
        current = current.replace("* ", "")
    except subprocess.CalledProcessError:
        current = "inside_docker"
    return current


def get_commit_hash():
    """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script"""
    # try:
    #     subprocess.check_output(['git', 'diff-index', '--quiet',
    #                              'HEAD'])  # Verify client is clean
    # except:
    #     raise RuntimeError(
    #         " !! Commit before training to get the commit hash.")
    try:
        commit = subprocess.check_output(['git', 'rev-parse', '--short',
                                          'HEAD']).decode().strip()
    # Not copying .git folder into docker container
    except subprocess.CalledProcessError:
        commit = "0000000"
    print(' > Git Hash: {}'.format(commit))
    return commit


def create_experiment_folder(root_path, model_name, debug):
    """Create an experiment folder named after the model, the current date
    and time, and the git commit hash."""
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    # if debug:
    #     commit_hash = 'debug'
    # else:
    commit_hash = get_commit_hash()
    output_folder = os.path.join(
        root_path, model_name + '-' + date_str + '-' + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder


def remove_experiment_folder(experiment_path):
    """Remove the experiment folder unless it contains a checkpoint file."""
    checkpoint_files = glob.glob(experiment_path + "/*.pth.tar")
    if len(checkpoint_files) < 1:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))


def copy_config_file(config_file, out_path, new_fields):
    config_name = os.path.basename(config_file)
    config_lines = open(config_file, "r").readlines()
    # add extra information fields
    for key, value in new_fields.items():
        if isinstance(value, str):
            new_line = '"{}":"{}",\n'.format(key, value)
        else:
            new_line = '"{}":{},\n'.format(key, value)
        # insert right after the opening brace (assumed to be on line 0)
        config_lines.insert(1, new_line)
    config_out_file = open(out_path, "w")
    config_out_file.writelines(config_lines)
    config_out_file.close()
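
# Usage sketch (paths and field names are illustrative): inject run metadata
# into the copy of the config kept with the experiment, e.g.
#     copy_config_file("config.json", "run/config.json",
#                      {"github_branch": get_git_branch()})
# Each new field is written as an extra JSON line after the opening brace.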


def _trim_model_state_dict(state_dict):
    r"""Remove the 'module.' prefix that torch.nn.DataParallel adds to
    state-dict keys, so the checkpoint can later be loaded with
    model.load_state_dict() on a model that is not wrapped in DataParallel."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    return new_state_dict
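
# Example: a DataParallel checkpoint key "module.encoder.weight" becomes
# "encoder.weight", matching the single-GPU model's state dict.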


def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
    new_state_dict = model.state_dict()
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict(),
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y")
    }
    torch.save(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'epoch': epoch,
            'linear_loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y")
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = os.path.join(out_path, bestmodel_path)
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
            model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss
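
# Usage sketch inside a validation loop (names are illustrative):
#     best_loss = save_best_model(model, optimizer, val_loss, best_loss,
#                                 OUT_PATH, current_step, epoch)
# The function returns the updated best loss, so the result must be
# assigned back for the comparison to work on the next call.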


def check_update(model, grad_clip):
    r'''Check model gradient against unexpected jumps and failures'''
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    if np.isinf(grad_norm):
        print(" | > Gradient is INF !!")
        skip_flag = True
    return grad_norm, skip_flag
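
# Usage sketch (illustrative): clip gradients and skip unstable updates.
#     loss.backward()
#     grad_norm, skip = check_update(model, grad_clip=1.0)
#     if not skip:
#         optimizer.step()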


def lr_decay(init_lr, global_step, warmup_steps):
    r'''Noam learning-rate schedule, from
    https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
    warmup_steps = float(warmup_steps)
    step = global_step + 1.
    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
                                                  step**-0.5)
    return lr
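
# Worked example (init_lr=0.001, warmup_steps=4000, step = global_step + 1):
#   step 1:     lr ~= 2.5e-07   (linear warmup region)
#   step 4000:  lr  = 0.001     (peak, reached at warmup_steps)
#   step 16000: lr  = 0.0005    (decays as init_lr * sqrt(warmup_steps / step))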


def weight_decay(optimizer, wd):
    """
    Custom weight decay operation, not affecting gradient values.
    """
    for group in optimizer.param_groups:
        for param in group['params']:
            current_lr = group['lr']
            # decoupled update: param <- param - lr * wd * param
            param.data = param.data.add(-wd * group['lr'], param.data)
    return optimizer, current_lr


class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super(NoamLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        return [
            base_lr * self.warmup_steps**0.5 * min(
                step * self.warmup_steps**-1.5, step**-0.5)
            for base_lr in self.base_lrs
        ]
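
# Usage sketch (illustrative values):
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     scheduler = NoamLR(optimizer, warmup_steps=4000)
#     ...
#     scheduler.step()  # once per training step; lr warms up, then decays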


def mk_decay(init_mk, max_epoch, n_epoch):
    return init_mk * ((max_epoch - n_epoch) / max_epoch)


def count_parameters(model):
    r"""Count number of trainable parameters in a network"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    batch_size = sequence_length.size(0)
    seq_range = torch.arange(0, max_len).long()
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    # B x T_max
    return seq_range_expand < seq_length_expand
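
# Worked example: sequence_mask(torch.LongTensor([1, 3])) returns
#     [[1, 0, 0],
#      [1, 1, 1]]
# i.e. a (batch x max_len) mask that is 1/True over valid time steps.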


def set_init_dict(model_dict, checkpoint, c):
    # Partial initialization: checkpoint layers that do not match the current
    # model definition (missing, resized, or listed in reinit_layers) are skipped.
    for k, v in checkpoint['model'].items():
        if k not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(k))
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in checkpoint['model'].items() if k in model_dict
    }
    # 2. filter out different size layers
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items()
        if v.numel() == model_dict[k].numel()
    }
    # 3. skip reinit layers
    if c.reinit_layers is not None:
        for reinit_layer_name in c.reinit_layers:
            pretrained_dict = {
                k: v
                for k, v in pretrained_dict.items()
                if reinit_layer_name not in k
            }
    # 4. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are restored.".format(
        len(pretrained_dict), len(model_dict)))
    return model_dict
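
# Usage sketch for a partial restore (names are illustrative):
#     checkpoint = torch.load(checkpoint_path)
#     model_dict = model.state_dict()
#     model_dict = set_init_dict(model_dict, checkpoint, c)
#     model.load_state_dict(model_dict)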


def setup_model(num_chars, c):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('models.' + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() in ["tacotron", "tacotrongst"]:
        model = MyModel(
            num_chars=num_chars,
            num_speakers=c.num_speakers,
            r=c.r,
            linear_dim=1025,
            mel_dim=80,
            memory_size=c.memory_size,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            separate_stopnet=c.separate_stopnet)
    elif c.model.lower() == "tacotron2":
        model = MyModel(
            num_chars=num_chars,
            num_speakers=c.num_speakers,
            r=c.r,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            separate_stopnet=c.separate_stopnet)
    return model
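
# Usage sketch (assumes a config whose "model" field names a class defined in
# models/<model>.py, e.g. "Tacotron2", and a hypothetical `symbols` alphabet):
#     c = load_config("config.json")
#     model = setup_model(num_chars=len(symbols), c=c)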