import os import glob import torch import shutil import datetime import subprocess import importlib import numpy as np def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8") current = next(line for line in out.split("\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" return current def get_commit_hash(): """https://stackoverflow.com/questions/14989858/get-the-current-git-hash-in-a-python-script""" # try: # subprocess.check_output(['git', 'diff-index', '--quiet', # 'HEAD']) # Verify client is clean # except: # raise RuntimeError( # " !! Commit before training to get the commit hash.") try: commit = subprocess.check_output( ['git', 'rev-parse', '--short', 'HEAD']).decode().strip() # Not copying .git folder into docker container except subprocess.CalledProcessError: commit = "0000000" print(' > Git Hash: {}'.format(commit)) return commit def create_experiment_folder(root_path, model_name, debug): """ Create a folder with the current date and time """ date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") # if debug: # commit_hash = 'debug' # else: commit_hash = get_commit_hash() output_folder = os.path.join( root_path, model_name + '-' + date_str + '-' + commit_hash) os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) return output_folder def remove_experiment_folder(experiment_path): """Check folder if there is a checkpoint, otherwise remove the folder""" checkpoint_files = glob.glob(experiment_path + "/*.pth.tar") if not checkpoint_files: if os.path.exists(experiment_path): shutil.rmtree(experiment_path, ignore_errors=True) print(" ! Run is removed from {}".format(experiment_path)) else: print(" ! Run is kept in {}".format(experiment_path)) def count_parameters(model): r"""Count number of trainable parameters in a network""" return sum(p.numel() for p in model.parameters() if p.requires_grad) def split_dataset(items): is_multi_speaker = False speakers = [item[-1] for item in items] is_multi_speaker = len(set(speakers)) > 1 eval_split_size = 500 if len(items) * 0.01 > 500 else int( len(items) * 0.01) np.random.seed(0) np.random.shuffle(items) if is_multi_speaker: items_eval = [] # most stupid code ever -- Fix it ! while len(items_eval) < eval_split_size: speakers = [item[-1] for item in items] speaker_counter = Counter(speakers) item_idx = np.random.randint(0, len(items)) if speaker_counter[items[item_idx][-1]] > 1: items_eval.append(items[item_idx]) del items[item_idx] return items_eval, items else: return items[:eval_split_size], items[eval_split_size:] # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() batch_size = sequence_length.size(0) seq_range = torch.arange(0, max_len).long() seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: seq_range_expand = seq_range_expand.to(sequence_length.device) seq_length_expand = ( sequence_length.unsqueeze(1).expand_as(seq_range_expand)) # B x T_max return seq_range_expand < seq_length_expand def set_init_dict(model_dict, checkpoint, c): # Partial initialization: if there is a mismatch with new and old layer, it is skipped. for k, v in checkpoint['model'].items(): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys pretrained_dict = { k: v for k, v in checkpoint['model'].items() if k in model_dict } # 2. filter out different size layers pretrained_dict = { k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel() } # 3. skip reinit layers if c.reinit_layers is not None: for reinit_layer_name in c.reinit_layers: pretrained_dict = { k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k } # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) return model_dict def setup_model(num_chars, num_speakers, c): print(" > Using model: {}".format(c.model)) MyModel = importlib.import_module('TTS.models.' + c.model.lower()) MyModel = getattr(MyModel, c.model) if c.model.lower() in "tacotron": model = MyModel(num_chars=num_chars, num_speakers=num_speakers, r=c.r, postnet_output_dim=c.audio['num_freq'], decoder_output_dim=c.audio['num_mels'], gst=c.use_gst, memory_size=c.memory_size, attn_type=c.attention_type, attn_win=c.windowing, attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask, location_attn=c.location_attn, attn_K=c.attention_heads, separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder) elif c.model.lower() == "tacotron2": model = MyModel(num_chars=num_chars, num_speakers=num_speakers, r=c.r, postnet_output_dim=c.audio['num_mels'], decoder_output_dim=c.audio['num_mels'], attn_type=c.attention_type, attn_win=c.windowing, attn_norm=c.attention_norm, prenet_type=c.prenet_type, prenet_dropout=c.prenet_dropout, forward_attn=c.use_forward_attn, trans_agent=c.transition_agent, forward_attn_mask=c.forward_attn_mask, location_attn=c.location_attn, attn_K=c.attention_heads, separate_stopnet=c.separate_stopnet, bidirectional_decoder=c.bidirectional_decoder) return model class KeepAverage(): def __init__(self): self.avg_values = {} self.iters = {} def __getitem__(self, key): return self.avg_values[key] def items(self): return self.avg_values.items() def add_value(self, name, init_val=0, init_iter=0): self.avg_values[name] = init_val self.iters[name] = init_iter def update_value(self, name, value, weighted_avg=False): if weighted_avg: self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value self.iters[name] += 1 else: self.avg_values[name] = self.avg_values[name] * \ self.iters[name] + value self.iters[name] += 1 self.avg_values[name] /= self.iters[name] def add_values(self, name_dict): for key, value in name_dict.items(): self.add_value(key, init_val=value) def update_values(self, value_dict): for key, value in value_dict.items(): self.update_value(key, value) def _check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): if alternative in c.keys() and c[alternative] is not None: return if restricted: assert name in c.keys(), f' [!] {name} not defined in config.json' if name in c.keys(): if max_val: assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' if min_val: assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' if enum_list: assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' if val_type: assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' def check_config(c): _check_argument('model', c, enum_list=['tacotron', 'tacotron2'], restricted=True, val_type=str) _check_argument('run_name', c, restricted=True, val_type=str) _check_argument('run_description', c, val_type=str) # AUDIO _check_argument('audio', c, restricted=True, val_type=dict) # audio processing parameters _check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) _check_argument('num_freq', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) _check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) _check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') _check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') _check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) _check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) _check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) # vocabulary parameters _check_argument('characters', c, restricted=False, val_type=dict) _check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) _check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) _check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) _check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) _check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) _check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) # normalization parameters _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) _check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) _check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) _check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) _check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) _check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) _check_argument('trim_db', c['audio'], restricted=True, val_type=int) # training parameters _check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) _check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) _check_argument('r', c, restricted=True, val_type=int, min_val=1) _check_argument('gradual_training', c, restricted=False, val_type=list) _check_argument('loss_masking', c, restricted=True, val_type=bool) # _check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # validation parameters _check_argument('run_eval', c, restricted=True, val_type=bool) _check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) _check_argument('test_sentences_file', c, restricted=False, val_type=str) # optimizer _check_argument('noam_schedule', c, restricted=False, val_type=bool) _check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) _check_argument('epochs', c, restricted=True, val_type=int, min_val=1) _check_argument('lr', c, restricted=True, val_type=float, min_val=0) _check_argument('wd', c, restricted=True, val_type=float, min_val=0) _check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) _check_argument('seq_len_norm', c, restricted=True, val_type=bool) # tacotron prenet _check_argument('memory_size', c, restricted=True, val_type=int, min_val=-1) _check_argument('prenet_type', c, restricted=True, val_type=str, enum_list=['original', 'bn']) _check_argument('prenet_dropout', c, restricted=True, val_type=bool) # attention _check_argument('attention_type', c, restricted=True, val_type=str, enum_list=['graves', 'original']) _check_argument('attention_heads', c, restricted=True, val_type=int) _check_argument('attention_norm', c, restricted=True, val_type=str, enum_list=['sigmoid', 'softmax']) _check_argument('windowing', c, restricted=True, val_type=bool) _check_argument('use_forward_attn', c, restricted=True, val_type=bool) _check_argument('forward_attn_mask', c, restricted=True, val_type=bool) _check_argument('transition_agent', c, restricted=True, val_type=bool) _check_argument('transition_agent', c, restricted=True, val_type=bool) _check_argument('location_attn', c, restricted=True, val_type=bool) _check_argument('bidirectional_decoder', c, restricted=True, val_type=bool) # stopnet _check_argument('stopnet', c, restricted=True, val_type=bool) _check_argument('separate_stopnet', c, restricted=True, val_type=bool) # tensorboard _check_argument('print_step', c, restricted=True, val_type=int, min_val=1) _check_argument('save_step', c, restricted=True, val_type=int, min_val=1) _check_argument('checkpoint', c, restricted=True, val_type=bool) _check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) # dataloading # pylint: disable=import-outside-toplevel from TTS.utils.text import cleaners _check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) _check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) _check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) _check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) _check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) _check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) _check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) # paths _check_argument('output_path', c, restricted=True, val_type=str) # multi-speaker gst _check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) _check_argument('style_wav_for_test', c, restricted=True, val_type=str) _check_argument('use_gst', c, restricted=True, val_type=bool) # datasets - checking only the first entry _check_argument('datasets', c, restricted=True, val_type=list) for dataset_entry in c['datasets']: _check_argument('name', dataset_entry, restricted=True, val_type=str) _check_argument('path', dataset_entry, restricted=True, val_type=str) _check_argument('meta_file_train', dataset_entry, restricted=True, val_type=str) _check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)