TTS/utils/generic_utils.py

510 lines
21 KiB
Python
Raw Normal View History

2018-01-22 09:48:59 +00:00
import os
2018-11-02 15:13:51 +00:00
import re
2018-01-22 09:48:59 +00:00
import glob
import shutil
import datetime
2018-01-22 16:20:20 +00:00
import json
2018-01-26 10:07:07 +00:00
import torch
2018-04-06 10:53:49 +00:00
import subprocess
import importlib
2018-01-22 09:48:59 +00:00
import numpy as np
2019-07-16 19:15:04 +00:00
from collections import OrderedDict, Counter
2018-01-22 09:48:59 +00:00
2018-01-22 16:20:20 +00:00
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    The instance's attribute namespace is the dictionary itself, so
    ``cfg.key`` and ``cfg["key"]`` refer to the same entry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
def load_config(config_path):
    """Load a JSON config file that may contain ``//`` line comments.

    Backslash-newline continuations are joined and everything from ``//`` to
    the end of its line is removed before the text is parsed as JSON.

    Args:
        config_path (str): path to the config file.

    Returns:
        AttrDict: parsed config whose keys are also readable as attributes.
    """
    config = AttrDict()
    with open(config_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    input_str = re.sub(r'\\\n', '', input_str)
    # `.` does not match newlines, so this strips each `//` comment up to the
    # end of its line, including a comment on the last line when the file has
    # no trailing newline (the previous `//.*\n` pattern missed that case).
    # NOTE(review): a `//` inside a string value (e.g. a URL) is stripped too.
    input_str = re.sub(r'//.*', '', input_str)
    data = json.loads(input_str)
    config.update(data)
    return config
2018-01-22 16:20:20 +00:00
def get_git_branch():
    """Return the name of the currently checked-out git branch.

    Parses the ``*``-marked line of ``git branch``. Returns
    ``"inside_docker"`` when git fails, is not installed, or reports no
    current branch (e.g. sources copied without their ``.git`` folder).
    """
    try:
        out = subprocess.check_output(["git", "branch"]).decode("utf8")
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "inside_docker"
    current = next(
        (line for line in out.split("\n") if line.startswith("*")), None)
    if current is None:
        return "inside_docker"
    # str.replace returns a new string; the result must be kept, otherwise
    # the "* " marker stays in the returned branch name.
    return current.replace("* ", "")
2018-04-06 10:53:49 +00:00
def get_commit_hash():
    """Return (and print) the abbreviated hash of the current git ``HEAD``.

    Based on
    https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script
    Returns ``"0000000"`` when the hash cannot be read, e.g. inside a Docker
    container where the ``.git`` folder or git itself is not available.
    """
    try:
        commit = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
    # Not copying .git folder into docker container, or git not installed.
    except (subprocess.CalledProcessError, FileNotFoundError):
        commit = "0000000"
    print(' > Git Hash: {}'.format(commit))
    return commit
2018-05-11 10:49:55 +00:00
def create_experiment_folder(root_path, model_name, debug):
    """Create (if needed) and return a run folder for a new experiment.

    The folder is ``<root_path>/<model_name>-<date>-<commit hash>``, where the
    date is formatted with ``"%B-%d-%Y_%I+%M%p"``. ``debug`` is accepted for
    API compatibility but is currently not used.
    """
    timestamp = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    folder_name = '-'.join([model_name, timestamp, get_commit_hash()])
    output_folder = os.path.join(root_path, folder_name)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))
    return output_folder
def remove_experiment_folder(experiment_path):
    """Delete a run folder unless it holds at least one checkpoint.

    The folder is kept (with a message) when it contains any ``*.pth.tar``
    file. Otherwise it is removed if it still exists, and a message is printed.
    """
    has_checkpoint = bool(glob.glob(experiment_path + "/*.pth.tar"))
    if has_checkpoint:
        print(" ! Run is kept in {}".format(experiment_path))
        return
    if os.path.exists(experiment_path):
        shutil.rmtree(experiment_path)
        print(" ! Run is removed from {}".format(experiment_path))
def copy_config_file(config_file, out_path, new_fields):
    """Copy a JSON config file to ``out_path``, injecting extra top-level fields.

    Each ``key: value`` pair of ``new_fields`` is written as a new line right
    after the first line of ``config_file`` (assumed to be the opening ``{``
    of the JSON object). Every line is inserted at index 1, so the last field
    of ``new_fields`` ends up first in the output.

    Args:
        config_file (str): path of the source config file.
        out_path (str): path of the file to write.
        new_fields (dict): extra fields to add to the copied config.
    """
    with open(config_file, "r") as f:
        config_lines = f.readlines()
    # add extra information fields
    for key, value in new_fields.items():
        # Serialize with json so strings are escaped and bool/None become
        # true/false/null (str() would produce invalid JSON). Values json
        # cannot handle fall back to their str() form, as before.
        try:
            value_str = json.dumps(value)
        except TypeError:
            value_str = '{}'.format(value)
        new_line = '{}:{},\n'.format(json.dumps(str(key)), value_str)
        config_lines.insert(1, new_line)
    with open(out_path, "w") as config_out_file:
        config_out_file.writelines(config_lines)
2018-01-22 09:48:59 +00:00
def _trim_model_state_dict(state_dict):
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
is loded for the next time by model.load_state(). Otherwise, it complains
about the torch.DataParallel()"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
2018-04-03 10:24:57 +00:00
name = k[7:] # remove `module.`
new_state_dict[name] = v
return new_state_dict
2018-07-20 10:23:44 +00:00
def save_checkpoint(model, optimizer, optimizer_st, model_loss, out_path,
                    current_step, epoch):
    """Write a training checkpoint to ``out_path/checkpoint_<step>.pth.tar``.

    The saved dict holds the model state dict, the optimizer state dict (or
    ``None`` when no optimizer is given), step, epoch, the loss under
    ``'linear_loss'``, today's date, and ``model.decoder.r``.

    Args:
        model: network to save; must expose ``decoder.r``.
        optimizer: optimizer whose state is saved, or ``None``.
        optimizer_st: unused; kept for API compatibility.
        model_loss: loss value stored with the checkpoint.
        out_path (str): folder the checkpoint is written to.
        current_step (int): global step, also used in the file name.
        epoch (int): current epoch.
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    checkpoint_path = os.path.join(out_path, file_name)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))
    model_state = model.state_dict()
    optimizer_state = None if optimizer is None else optimizer.state_dict()
    state = dict(
        model=model_state,
        optimizer=optimizer_state,
        step=current_step,
        epoch=epoch,
        linear_loss=model_loss,
        date=datetime.date.today().strftime("%B %d, %Y"),
        r=model.decoder.r,
    )
    torch.save(state, checkpoint_path)
2018-02-13 09:45:52 +00:00
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step, epoch):
    """Save ``out_path/best_model.pth.tar`` when ``model_loss < best_loss``.

    The saved dict has the same layout as a regular checkpoint: model and
    optimizer state dicts, step, epoch, the loss under ``'linear_loss'``,
    today's date and ``model.decoder.r``.

    Args:
        model: network to save; must expose ``decoder.r``.
        optimizer: optimizer whose state is saved, or ``None``.
        model_loss: loss of the current model.
        best_loss: best loss seen so far.
        out_path (str): folder the model is written to.
        current_step (int): global step.
        epoch (int): current epoch.

    Returns:
        ``model_loss`` if it improved on ``best_loss`` (and the model was
        saved), otherwise ``best_loss``.
    """
    # `not <` (rather than `>=`) keeps the original behaviour for NaN losses.
    if not model_loss < best_loss:
        return best_loss
    state = {
        'model': model.state_dict(),
        # Same None guard as save_checkpoint(): the optimizer is optional.
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'epoch': epoch,
        'linear_loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': model.decoder.r
    }
    bestmodel_path = os.path.join(out_path, 'best_model.pth.tar')
    print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
        model_loss, bestmodel_path))
    torch.save(state, bestmodel_path)
    return model_loss
2018-01-22 09:48:59 +00:00
def check_update(model, grad_clip, ignore_stopnet=False):
    r'''Clip model gradients and flag steps whose gradient norm is not finite.

    Args:
        model (torch.nn.Module): model whose gradients are clipped in place.
        grad_clip (float): max total gradient norm for ``clip_grad_norm_``.
        ignore_stopnet (bool): if True, parameters whose name contains
            ``'stopnet'`` are neither clipped nor counted in the norm.

    Returns:
        tuple: ``(grad_norm, skip_flag)``. ``grad_norm`` is the value
        returned by ``clip_grad_norm_``. ``skip_flag`` is True when that
        norm is INF or NaN, i.e. the update should be skipped.
    '''
    if ignore_stopnet:
        params = [param for name, param in model.named_parameters()
                  if 'stopnet' not in name]
    else:
        params = model.parameters()
    grad_norm = torch.nn.utils.clip_grad_norm_(params, grad_clip)
    # clip_grad_norm_ returns a tensor (possibly on the GPU) in recent PyTorch
    # and a float in old versions; float() handles both, whereas np.isinf()
    # cannot convert a CUDA tensor.
    norm_value = float(grad_norm)
    skip_flag = False
    if np.isinf(norm_value):
        print(" | > Gradient is INF !!")
        skip_flag = True
    elif np.isnan(norm_value):
        # A NaN norm corrupts the weights just like an INF one.
        print(" | > Gradient is NaN !!")
        skip_flag = True
    return grad_norm, skip_flag
def lr_decay(init_lr, global_step, warmup_steps):
    r'''Learning rate with linear warm-up followed by inverse-sqrt decay.

    from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py
    With ``step = global_step + 1`` the rate is
    ``init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)``,
    which equals ``init_lr`` when ``step == warmup_steps``.
    '''
    warmup = float(warmup_steps)
    step = global_step + 1.
    warmup_factor = step * warmup**-1.5
    decay_factor = step**-0.5
    return init_lr * warmup**0.5 * np.minimum(warmup_factor, decay_factor)
2018-02-23 14:20:22 +00:00
def adam_weight_decay(optimizer):
    """
    Custom weight decay operation, not affecting grad values.

    Every parameter ``p`` of every param group is replaced by
    ``p - weight_decay * lr * p``, using that group's ``lr`` and
    ``weight_decay``.

    Returns:
        tuple: ``(optimizer, current_lr)`` where ``current_lr`` is the ``lr``
        of the last param group (``None`` if there are no groups).
    """
    current_lr = None
    for group in optimizer.param_groups:
        # Read per group, so it is bound even when a group has no params.
        current_lr = group['lr']
        weight_decay = group['weight_decay']
        for param in group['params']:
            # Keyword `alpha` replaces the deprecated add(Number, Tensor) form.
            param.data = param.data.add(param.data,
                                        alpha=-weight_decay * group['lr'])
    return optimizer, current_lr
def set_weight_decay(model, weight_decay,
                     skip_list=("decoder.attention.v", "rnn", "lstm", "gru",
                                "embedding")):
    """Split trainable parameters into weight-decay and no-weight-decay groups.

    Skip biases, BatchNorm parameters (any 1-D parameter), rnns, embeddings
    and the attention projection layer v: a parameter goes to the no-decay
    group when it is 1-D or its name contains any substring of
    ``skip_list``. Parameters with ``requires_grad=False`` are left out.

    Args:
        model (torch.nn.Module): model whose parameters are grouped.
        weight_decay (float): decay applied to the second group.
        skip_list (Iterable[str]): name substrings exempt from decay. The
            default is an immutable tuple, so it cannot be shared and
            mutated across calls.

    Returns:
        list: two param-group dicts for a ``torch.optim`` optimizer; the
        first has ``weight_decay`` 0, the second ``weight_decay``.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if len(param.shape) == 1 or any(skip_name in name for skip_name in skip_list):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{
        'params': no_decay,
        'weight_decay': 0.
    }, {
        'params': decay,
        'weight_decay': weight_decay
    }]
2018-11-26 13:09:42 +00:00
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """Noam learning-rate schedule: linear warm-up, then inverse-sqrt decay.

    For ``step = max(last_epoch, 1)`` every base learning rate is multiplied
    by ``warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)``,
    a factor that equals 1 when ``step == warmup_steps``.
    """

    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
        self.warmup_steps = float(warmup_steps)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = max(self.last_epoch, 1)
        warmup_scale = self.warmup_steps**0.5
        factor = min(step * self.warmup_steps**-1.5, step**-0.5)
        return [base_lr * warmup_scale * factor for base_lr in self.base_lrs]
def mk_decay(init_mk, max_epoch, n_epoch):
    """Linearly decay ``init_mk`` towards zero over ``max_epoch`` epochs.

    Returns ``init_mk`` scaled by ``(max_epoch - n_epoch) / max_epoch``.
    """
    remaining_fraction = (max_epoch - n_epoch) / max_epoch
    return init_mk * remaining_fraction
2018-02-23 14:20:22 +00:00
def count_parameters(model):
    r"""Return the number of scalar values in the trainable parameters of ``model``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    """Build a boolean mask from a 1-D tensor of sequence lengths.

    Args:
        sequence_length (torch.Tensor): 1-D tensor with one length per item.
        max_len (int, optional): mask width; defaults to the largest length.

    Returns:
        torch.Tensor: bool tensor of shape ``(batch, max_len)`` where
        ``mask[b, t]`` is True iff ``t < sequence_length[b]``.
    """
    if max_len is None:
        max_len = int(sequence_length.max())
    # Create the range directly on the lengths' device. The previous
    # `.cuda()` moved it to the *current* CUDA device, which breaks when
    # `sequence_length` lives on another GPU.
    seq_range = torch.arange(max_len, device=sequence_length.device)
    # (1, T_max) vs (B, 1) broadcasts to B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
def set_init_dict(model_dict, checkpoint, c):
    """Partially initialize a model state dict from a checkpoint.

    A checkpoint entry is copied into ``model_dict`` only when its key exists
    in the model, its tensor shape matches, and the key contains none of the
    substrings in ``c.reinit_layers``. Checkpoint layers missing from the
    model are reported.

    Args:
        model_dict (dict): the model's state dict; updated in place.
        checkpoint (dict): loaded checkpoint holding weights under ``'model'``.
        c: config; the optional ``c.reinit_layers`` (list of name substrings
            or None) selects layers that keep their fresh initialization.

    Returns:
        dict: the updated ``model_dict``.
    """
    checkpoint_state = checkpoint['model']
    for k in checkpoint_state:
        if k not in model_dict:
            print(" | > Layer missing in the model definition: {}".format(k))
    # 1. filter out unnecessary keys
    # 2. filter out different size layers. Shapes (not element counts) are
    #    compared: a same-size but differently-shaped tensor would make
    #    load_state_dict() fail with a size mismatch.
    pretrained_dict = {
        k: v
        for k, v in checkpoint_state.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }
    # 3. skip reinit layers (the config may not define `reinit_layers`)
    reinit_layers = getattr(c, 'reinit_layers', None)
    if reinit_layers is not None:
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if not any(layer_name in k for layer_name in reinit_layers)
        }
    # 4. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
                                                     len(model_dict)))
    return model_dict
2019-07-10 16:38:55 +00:00
def setup_model(num_chars, num_speakers, c):
    """Instantiate the TTS model named by ``c.model``.

    ``c.model`` is the class name, and lower-cased it is the module name under
    ``TTS.models`` (e.g. ``Tacotron`` -> ``TTS.models.tacotron``). Tacotron
    gets ``postnet_output_dim=c.audio['num_freq']`` plus the ``gst`` and
    ``memory_size`` options; Tacotron2 gets
    ``postnet_output_dim=c.audio['num_mels']``.

    Raises:
        ValueError: if ``c.model`` is neither tacotron nor tacotron2.
    """
    print(" > Using model: {}".format(c.model))
    model_name = c.model.lower()
    # Exact match: the previous `in "tacotron"` was a substring test, and an
    # unknown name ended in an unbound `model` variable.
    if model_name not in ("tacotron", "tacotron2"):
        raise ValueError(" [!] Unknown model: {}".format(c.model))
    model_module = importlib.import_module('TTS.models.' + model_name)
    MyModel = getattr(model_module, c.model)
    # Arguments shared by both architectures.
    common_kwargs = dict(
        num_chars=num_chars,
        num_speakers=num_speakers,
        r=c.r,
        decoder_output_dim=c.audio['num_mels'],
        attn_type=c.attention_type,
        attn_win=c.windowing,
        attn_norm=c.attention_norm,
        prenet_type=c.prenet_type,
        prenet_dropout=c.prenet_dropout,
        forward_attn=c.use_forward_attn,
        trans_agent=c.transition_agent,
        forward_attn_mask=c.forward_attn_mask,
        location_attn=c.location_attn,
        attn_K=c.attention_heads,
        separate_stopnet=c.separate_stopnet,
        bidirectional_decoder=c.bidirectional_decoder)
    if model_name == "tacotron":
        model = MyModel(postnet_output_dim=c.audio['num_freq'],
                        gst=c.use_gst,
                        memory_size=c.memory_size,
                        **common_kwargs)
    else:  # tacotron2
        model = MyModel(postnet_output_dim=c.audio['num_mels'],
                        **common_kwargs)
    return model
2019-07-16 19:15:04 +00:00
def split_dataset(items):
    """Split dataset items into evaluation and training sets.

    The evaluation set gets 1% of the items, capped at 500. ``items`` is
    shuffled in place after ``np.random.seed(0)``, which also resets NumPy's
    global RNG. For multi-speaker data (more than one distinct value in the
    last field of the items), evaluation items are drawn at random, but only
    from speakers that still have more than one item left. Every speaker
    therefore keeps at least one training item.

    Args:
        items (list): dataset entries whose last element is the speaker;
            shuffled (and, for multi-speaker data, shrunk) in place.

    Returns:
        tuple: ``(eval_items, train_items)``. For multi-speaker data the
        evaluation set may be smaller than requested if no speaker can give
        up any more items.
    """
    speakers = [item[-1] for item in items]
    is_multi_speaker = len(set(speakers)) > 1
    eval_split_size = min(500, int(len(items) * 0.01))
    np.random.seed(0)
    np.random.shuffle(items)
    if not is_multi_speaker:
        return items[:eval_split_size], items[eval_split_size:]
    items_eval = []
    # Keep the per-speaker counts up to date instead of recounting all items
    # on every draw; the random draws, and so the result, are unchanged.
    speaker_counter = Counter(item[-1] for item in items)
    # Number of remaining items whose speaker still has more than one item.
    n_donatable = sum(count for count in speaker_counter.values() if count > 1)
    # Stopping at n_donatable == 0 avoids the endless loop of the old code.
    while len(items_eval) < eval_split_size and n_donatable > 0:
        item_idx = np.random.randint(0, len(items))
        speaker = items[item_idx][-1]
        count = speaker_counter[speaker]
        if count > 1:
            items_eval.append(items.pop(item_idx))
            speaker_counter[speaker] = count - 1
            # With one item left, none of that speaker's items is donatable.
            n_donatable -= 2 if count == 2 else 1
    return items_eval, items
2019-07-22 00:11:20 +00:00
def gradual_training_scheduler(global_step, config):
    """Setup the gradual training schedule wrt number of active GPUs.

    ``config.gradual_training`` is a list of entries whose first item is a
    starting step. The last entry whose start is <= ``global_step *
    num_gpus`` (at least 1 GPU counted) is selected, and its second and third
    items are returned (presumably ``r`` and batch size; confirm against the
    config).

    Raises:
        ValueError: if no entry starts at or before the effective step.
    """
    num_gpus = max(torch.cuda.device_count(), 1)
    # we set the scheduling wrt num_gpus
    effective_step = global_step * num_gpus
    new_values = None
    for values in config.gradual_training:
        if effective_step >= values[0]:
            new_values = values
    if new_values is None:
        raise ValueError(
            " [!] No gradual_training entry starts at or before step {}."
            .format(effective_step))
    return new_values[1], new_values[2]
class KeepAverage():
    """Track running averages of named values.

    Each name has a current average (``avg_values``) and an update count
    (``iters``). ``update_value`` folds a new value into a plain running mean
    or, with ``weighted_avg=True``, into ``0.99 * avg + 0.01 * value``.
    Names registered through ``add_values`` start with a count of 0, so their
    first plain update replaces the initial value.
    """

    def __init__(self):
        self.avg_values = {}
        self.iters = {}

    def __getitem__(self, key):
        return self.avg_values[key]

    def add_value(self, name, init_val=0, init_iter=0):
        """Register ``name`` with a starting average and update count."""
        self.avg_values[name] = init_val
        self.iters[name] = init_iter

    def update_value(self, name, value, weighted_avg=False):
        """Fold ``value`` into the average kept under ``name``."""
        previous = self.avg_values[name]
        count = self.iters[name]
        if weighted_avg:
            new_avg = 0.99 * previous + 0.01 * value
        else:
            new_avg = (previous * count + value) / (count + 1)
        self.avg_values[name] = new_avg
        self.iters[name] = count + 1

    def add_values(self, name_dict):
        """Register every ``name: initial value`` pair of ``name_dict``."""
        for name, init_val in name_dict.items():
            self.add_value(name, init_val=init_val)

    def update_values(self, value_dict):
        """Apply a plain running-mean update for every pair in ``value_dict``."""
        for name, value in value_dict.items():
            self.update_value(name, value)
2020-02-13 21:16:40 +00:00
def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None):
if restricted:
assert name in c.keys(), f' [!] {name} not defined in config.json'
if name in c.keys():
if max_val:
assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
if min_val:
assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
if enum_list:
assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
if val_type:
2020-02-14 17:00:15 +00:00
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
2020-02-13 21:16:40 +00:00
def check_config(c):
    """Validate a loaded training config entry by entry.

    Every check is delegated to ``_check_argument``. ``restricted=True``
    entries must be present; the others are validated only when present.
    The first invalid entry fails ``_check_argument``'s assertions.

    Args:
        c (dict): config, e.g. as returned by ``load_config``.
    """
    _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str)
    _check_argument('run_name', c, restricted=True, val_type=str)
    _check_argument('run_description', c, val_type=str)

    # AUDIO
    _check_argument('audio', c, restricted=True, val_type=dict)

    # audio processing parameters
    _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000)
    _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000)
    _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

    # normalization parameters
    _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
    _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
    _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
    _check_argument('trim_db', c['audio'], restricted=True, val_type=int)

    # training parameters
    _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
    _check_argument('r', c, restricted=True, val_type=int, min_val=1)
    _check_argument('gradual_training', c, restricted=False, val_type=list)
    _check_argument('loss_masking', c, restricted=True, val_type=bool)
    # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

    # validation parameters
    _check_argument('run_eval', c, restricted=True, val_type=bool)
    _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
    _check_argument('test_sentences_file', c, restricted=False, val_type=str)

    # optimizer
    _check_argument('noam_schedule', c, restricted=False, val_type=bool)
    _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
    _check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    _check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    _check_argument('wd', c, restricted=True, val_type=float, min_val=0)
    _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    _check_argument('seq_len_norm', c, restricted=True, val_type=bool)

    # tacotron prenet
    _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1)
    _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn'])
    _check_argument('prenet_dropout', c, restricted=True, val_type=bool)

    # attention
    _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original'])
    _check_argument('attention_heads', c, restricted=True, val_type=int)
    _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax'])
    _check_argument('windowing', c, restricted=True, val_type=bool)
    _check_argument('use_forward_attn', c, restricted=True, val_type=bool)
    _check_argument('forward_attn_mask', c, restricted=True, val_type=bool)
    _check_argument('transition_agent', c, restricted=True, val_type=bool)
    _check_argument('location_attn', c, restricted=True, val_type=bool)
    _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool)

    # stopnet
    _check_argument('stopnet', c, restricted=True, val_type=bool)
    _check_argument('separate_stopnet', c, restricted=True, val_type=bool)

    # tensorboard
    _check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
    _check_argument('checkpoint', c, restricted=True, val_type=bool)
    _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

    # dataloading
    _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=['english_cleaners', 'phoneme_cleaners', 'transliteration_cleaners', 'basic_cleaners'])
    _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
    _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
    _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
    _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
    _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)

    # paths
    _check_argument('output_path', c, restricted=True, val_type=str)

    # multi-speaker gst
    _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
    _check_argument('style_wav_for_test', c, restricted=True, val_type=str)
    _check_argument('use_gst', c, restricted=True, val_type=bool)

    # datasets - every entry in the list is checked
    _check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        _check_argument('name', dataset_entry, restricted=True, val_type=str)
        _check_argument('path', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str)
        _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)