TTS/utils/generic_utils.py

300 lines
11 KiB
Python
Raw Normal View History

2018-01-22 09:48:59 +00:00
import os
import sys
import glob
import time
import shutil
import datetime
2018-01-22 16:20:20 +00:00
import json
2018-01-26 10:07:07 +00:00
import torch
2018-04-06 10:53:49 +00:00
import subprocess
2018-01-22 09:48:59 +00:00
import numpy as np
2018-02-26 13:33:54 +00:00
from collections import OrderedDict
2018-04-25 12:38:23 +00:00
from torch.autograd import Variable
2018-01-22 09:48:59 +00:00
2018-01-22 16:20:20 +00:00
class AttrDict(dict):
    """A ``dict`` whose keys can also be read and assigned as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance attribute namespace at the mapping itself, so
        # ``obj.key`` and ``obj['key']`` share one storage.
        self.__dict__ = self
def load_config(config_path):
    """Load a JSON configuration file into an ``AttrDict``.

    Args:
        config_path: Path to the JSON file. Its top-level value is passed to
            ``dict.update``, so it is expected to be a JSON object.

    Returns:
        AttrDict holding the file's top-level keys, readable as attributes.
    """
    config = AttrDict()
    # The context manager closes the handle deterministically instead of
    # leaving it to garbage collection; JSON text is UTF-8, so do not depend
    # on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config.update(json.load(config_file))
    return config
2018-01-22 16:20:20 +00:00
2018-04-06 10:53:49 +00:00
def get_commit_hash():
    """Return the short git hash of ``HEAD`` and print it.

    Based on https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script

    Returns:
        The output of ``git rev-parse --short HEAD`` as a stripped string.

    Raises:
        RuntimeError: if ``git diff-index --quiet HEAD`` fails (the working
            tree is not clean, this is not a git checkout) or git cannot be
            executed at all. The underlying error is chained as the cause.
    """
    try:
        # Verify client is clean: refuse to record a hash that would not
        # describe the code actually being run.
        subprocess.check_output(['git', 'diff-index', '--quiet', 'HEAD'])
    except (subprocess.CalledProcessError, OSError) as err:
        # Narrowed from a bare ``except:`` so Ctrl-C / SystemExit are not
        # converted into this error, and the real cause stays visible.
        raise RuntimeError(" !! Commit before training to get the commit hash.") from err
    commit = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    print(' > Git Hash: {}'.format(commit))
    return commit
2018-05-11 10:49:55 +00:00
def create_experiment_folder(root_path, model_name, debug):
    """Create (if needed) and return a run folder under ``root_path``.

    The folder is named ``<date>-<model_name>-<tag>``, where ``<date>`` is the
    current local time formatted as ``%B-%d-%Y_%I+%M%p`` and ``<tag>`` is
    ``'debug'`` in debug mode or the current git commit hash otherwise.

    Args:
        root_path: Parent directory for experiment folders.
        model_name: Model name embedded in the folder name.
        debug: When truthy, skip the git lookup and use the ``debug`` tag.

    Returns:
        Path of the experiment folder.
    """
    stamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    tag = 'debug' if debug else get_commit_hash()
    folder_name = '-'.join([stamp, model_name, tag])
    run_dir = os.path.join(root_path, folder_name)
    os.makedirs(run_dir, exist_ok=True)
    print(" > Experiment folder: {}".format(run_dir))
    return run_dir
def remove_experiment_folder(experiment_path):
    """Delete an experiment folder unless it contains a checkpoint.

    A folder holding at least one ``*.pth.tar`` file is kept; otherwise it
    is removed (if it exists). Either outcome is reported on stdout.

    Args:
        experiment_path: Folder of the run to clean up.
    """
    # Escape the folder part so characters such as '[', '*' or '?' in the
    # run name are matched literally. Unescaped, they act as wildcards, the
    # existing checkpoints are not found and the folder would be deleted.
    pattern = os.path.join(glob.escape(experiment_path), "*.pth.tar")
    checkpoint_files = glob.glob(pattern)
    if not checkpoint_files:
        if os.path.exists(experiment_path):
            shutil.rmtree(experiment_path)
            print(" ! Run is removed from {}".format(experiment_path))
    else:
        print(" ! Run is kept in {}".format(experiment_path))
def copy_config_file(config_file, path):
    """Copy ``config_file`` into the directory ``path``, keeping its file name."""
    destination = os.path.join(path, os.path.basename(config_file))
    shutil.copyfile(config_file, destination)
def _trim_model_state_dict(state_dict):
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
is loded for the next time by model.load_state(). Otherwise, it complains
about the torch.DataParallel()"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
2018-04-03 10:24:57 +00:00
name = k[7:] # remove `module.`
new_state_dict[name] = v
return new_state_dict
2018-02-21 15:21:44 +00:00
def save_checkpoint(model, optimizer, model_loss, out_path,
                    current_step, epoch):
    """Write a training checkpoint to ``<out_path>/checkpoint_<current_step>.pth.tar``.

    The saved dict has the keys ``model`` (the model state dict passed
    through ``_trim_model_state_dict``), ``optimizer``, ``step``, ``epoch``,
    ``linear_loss`` (``model_loss``) and ``date`` (today, as text).
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
    # Keyword arguments are evaluated left to right, preserving the original
    # order of the state_dict() calls and of the dict keys.
    state = dict(
        model=_trim_model_state_dict(model.state_dict()),
        optimizer=optimizer.state_dict(),
        step=current_step,
        epoch=epoch,
        linear_loss=model_loss,
        date=datetime.date.today().strftime("%B %d, %Y"),
    )
    torch.save(state, checkpoint_path)
2018-02-13 09:45:52 +00:00
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    """Save ``<out_path>/best_model.pth.tar`` when ``model_loss`` improves on ``best_loss``.

    Args:
        model: Model whose state dict is stored (after
            ``_trim_model_state_dict``).
        optimizer: Optimizer whose state dict is stored.
        model_loss: Loss of the current model.
        best_loss: Lowest loss seen so far.
        out_path: Directory to write the file into.
        current_step: Global step, stored under ``step``.
        epoch: Epoch number, stored under ``epoch``.

    Returns:
        ``model_loss`` if it was strictly lower (and the model was saved),
        otherwise ``best_loss`` unchanged.
    """
    # Phrased as "not <" rather than ">=" so that a NaN loss never counts as
    # an improvement.
    if not (model_loss < best_loss):
        return best_loss
    state = {
        'model': _trim_model_state_dict(model.state_dict()),
        'optimizer': optimizer.state_dict(),
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y"),
    }
    target_file = os.path.join(out_path, 'best_model.pth.tar')
    print(" | > Best model saving with loss {0:.2f} : {1:}".format(
        model_loss, target_file))
    torch.save(state, target_file)
    return model_loss
2018-01-22 09:48:59 +00:00
def check_update(model, grad_clip, grad_top):
    r'''Check model gradient against unexpected jumps and failures.

    Gradients are clipped with ``torch.nn.utils.clip_grad_norm_`` to
    ``grad_clip``, and the update is flagged for skipping when the reported
    norm is infinite, NaN, or larger than ``grad_top``.

    Args:
        model: Module whose ``parameters()`` carry the gradients.
        grad_clip: Max norm passed to ``clip_grad_norm_``.
        grad_top: Norm above which the update should be skipped.

    Returns:
        (grad_norm, skip_flag): the value returned by ``clip_grad_norm_``
        (passed through unchanged) and whether the caller should skip the
        optimizer step.
    '''
    skip_flag = False
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    # Depending on the PyTorch version, clip_grad_norm_ returns a Python float
    # or a 0-dim tensor; float() accepts both (np.isinf on a CUDA tensor
    # would fail).
    norm_value = float(grad_norm)
    if np.isinf(norm_value):
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif np.isnan(norm_value):
        # NaN compares False against grad_top, so it would otherwise slip
        # through and poison the weights.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    elif norm_value > grad_top:
        print(" | > Gradient is above the top limit !!")
        skip_flag = True
    return grad_norm, skip_flag
def lr_decay(init_lr, global_step, warmup_steps):
    r'''Warm-up then inverse-square-root learning-rate schedule.

    With ``step = global_step + 1`` and ``w = warmup_steps`` this returns
    ``init_lr * sqrt(w) * min(step * w**-1.5, step**-0.5)``. That equals
    ``init_lr * step / w`` (a linear ramp reaching ``init_lr`` at
    ``step == w``) and afterwards ``init_lr * sqrt(w / step)``.

    from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py
    '''
    warmup = float(warmup_steps)
    step = global_step + 1.
    ramp_factor = step * warmup ** -1.5
    decay_factor = step ** -0.5
    return init_lr * warmup ** 0.5 * np.minimum(ramp_factor, decay_factor)
2018-02-23 14:20:22 +00:00
def mk_decay(init_mk, max_epoch, n_epoch):
    """Scale ``init_mk`` linearly from full value at epoch 0 to 0 at ``max_epoch``."""
    remaining_fraction = (max_epoch - n_epoch) / max_epoch
    return init_mk * remaining_fraction
2018-02-23 14:20:22 +00:00
def count_parameters(model):
    r"""Return the total element count of ``model``'s parameters with ``requires_grad`` set."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Build a padding mask from per-sequence lengths.

    Args:
        sequence_length: 1-D tensor of lengths, one per batch item (it is
            indexed with ``size(0)`` and ``unsqueeze(1)``).
        max_len: Mask width; defaults to ``sequence_length.max()``.

    Returns:
        Tensor of shape ``(batch, max_len)`` on ``sequence_length``'s device,
        where entry ``[b, t]`` is ``t < sequence_length[b]`` (a bool tensor on
        recent PyTorch).
    """
    if max_len is None:
        max_len = sequence_length.max()
    # A plain int keeps arange/expand independent of a 0-dim tensor argument.
    max_len = int(max_len)
    batch_size = sequence_length.size(0)
    # Create the range directly on the lengths' device. The old ``.cuda()``
    # call moved it to the *current default* GPU, which breaks when the input
    # lives on another device (e.g. a DataParallel replica).
    seq_range = torch.arange(0, max_len, dtype=torch.long,
                             device=sequence_length.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = (sequence_length.unsqueeze(1)
                         .expand_as(seq_range_expand))
    return seq_range_expand < seq_length_expand
2018-01-22 09:48:59 +00:00
class Progbar(object):
    """Displays a progress bar.

    Args:
        target: Total number of steps expected, None if unknown.
        width: Width of the bar itself, in characters.
        verbose: Display mode. 1 redraws the bar (rate-limited by
            ``interval``); 2 writes one summary line once ``target`` is
            reached, or on every update when ``target`` is None; any other
            value writes nothing.
        interval: Minimum visual progress update interval (in seconds).
    """

    def __init__(self, target, width=30, verbose=1, interval=0.05):
        self.width = width
        self.target = target
        # name -> [sum of value * steps, number of steps], for averaging
        self.sum_values = {}
        # metric names in first-seen order (display order)
        self.unique_values = []
        self.start = time.time()
        self.last_update = 0
        self.interval = interval
        # length of the last line written, used to blank out leftovers
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose
        # Redraw in place on a TTY or when ipykernel is loaded; otherwise
        # every redraw starts on a new line.
        self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                                  sys.stdout.isatty()) or
                                 'ipykernel' in sys.modules)

    def update(self, current, values=None, force=False):
        """Updates the progress bar.

        Args:
            current: Index of current step.
            values: List of tuples (name, value_for_last_step).
                The progress bar will display averages for these values.
            force: Whether to force visual progress update.
        """
        values = values or []
        for k, v in values:
            # Weight each value by the steps elapsed since the last update so
            # the displayed figure is a per-step average.
            if k not in self.sum_values:
                self.sum_values[k] = [v * (current - self.seen_so_far),
                                      current - self.seen_so_far]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * (current - self.seen_so_far)
                self.sum_values[k][1] += (current - self.seen_so_far)
        self.seen_so_far = current

        now = time.time()
        info = ' - %.0fs' % (now - self.start)
        if self.verbose == 1:
            # Rate-limit redraws, except for forced updates, unknown targets
            # and the final step.
            if (not force and (now - self.last_update) < self.interval and
                    self.target is not None and current < self.target):
                return

            prev_total_width = self.total_width
            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self.target is not None:
                # np.log10(0) is -inf and int(-inf) raises OverflowError, so a
                # non-positive target gets a one-digit counter instead.
                if self.target > 0:
                    numdigits = int(np.floor(np.log10(self.target))) + 1
                else:
                    numdigits = 1
                barstr = '%%%dd/%d [' % (numdigits, self.target)
                bar = barstr % current
                # A zero-step target counts as already complete instead of
                # dividing by zero.
                prog = float(current) / self.target if self.target else 1.0
                prog_width = int(self.width * prog)
                if prog_width > 0:
                    bar += ('=' * (prog_width - 1))
                    if current < self.target:
                        bar += '>'
                    else:
                        bar += '='
                bar += ('.' * (self.width - prog_width))
                bar += ']'
            else:
                bar = '%7d/Unknown' % current

            self.total_width = len(bar)
            sys.stdout.write(bar)

            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            if self.target is not None and current < self.target:
                eta = time_per_unit * (self.target - current)
                if eta > 3600:
                    eta_format = '%d:%02d:%02d' % (
                        eta // 3600, (eta % 3600) // 60, eta % 60)
                elif eta > 60:
                    eta_format = '%d:%02d' % (eta // 60, eta % 60)
                else:
                    eta_format = '%ds' % eta
                # While running, the ETA replaces the elapsed-time prefix.
                info = ' - ETA: %s' % eta_format

            # Per-step timing is appended in every case.
            if time_per_unit >= 1:
                info += ' %.0fs/step' % time_per_unit
            elif time_per_unit >= 1e-3:
                info += ' %.0fms/step' % (time_per_unit * 1e3)
            else:
                info += ' %.0fus/step' % (time_per_unit * 1e6)

            for k in self.unique_values:
                info += ' - %s:' % k
                if isinstance(self.sum_values[k], list):
                    avg = np.mean(
                        self.sum_values[k][0] / max(1, self.sum_values[k][1]))
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                else:
                    info += ' %s' % self.sum_values[k]

            self.total_width += len(info)
            # Pad with spaces so a shorter line fully overwrites the previous one.
            if prev_total_width > self.total_width:
                info += (' ' * (prev_total_width - self.total_width))

            if self.target is not None and current >= self.target:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()

        elif self.verbose == 2:
            if self.target is None or current >= self.target:
                for k in self.unique_values:
                    info += ' - %s:' % k
                    avg = np.mean(
                        self.sum_values[k][0] / max(1, self.sum_values[k][1]))
                    # abs() keeps negative metrics in fixed-point notation,
                    # matching the verbose == 1 branch.
                    if abs(avg) > 1e-3:
                        info += ' %.4f' % avg
                    else:
                        info += ' %.4e' % avg
                info += '\n'

                sys.stdout.write(info)
                sys.stdout.flush()

        self.last_update = now

    def add(self, n, values=None):
        """Advance the bar by ``n`` steps relative to the last seen step."""
        self.update(self.seen_so_far + n, values)