diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py new file mode 100644 index 00000000..39466ffd --- /dev/null +++ b/TTS/bin/train_speedy_speech.py @@ -0,0 +1,619 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import argparse +import glob +import os +import sys +import time +import traceback +import numpy as np +from random import randrange + +import torch +# DISTRIBUTED +from torch.nn.parallel import DistributedDataParallel as DDP_th +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from TTS.tts.datasets.preprocess import load_meta_data +from TTS.tts.datasets.TTSDataset import MyDataset +from TTS.tts.layers.losses import SpeedySpeechLoss +from TTS.tts.utils.generic_utils import check_config_tts, setup_model +from TTS.tts.utils.io import save_best_model, save_checkpoint +from TTS.tts.utils.measures import alignment_diagonal_score +from TTS.tts.utils.speakers import parse_speakers +from TTS.tts.utils.synthesis import synthesis +from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram +from TTS.utils.audio import AudioProcessor +from TTS.utils.console_logger import ConsoleLogger +from TTS.utils.distribute import init_distributed, reduce_tensor +from TTS.utils.generic_utils import (KeepAverage, count_parameters, + create_experiment_folder, get_git_branch, + remove_experiment_folder, set_init_dict) +from TTS.utils.io import copy_config_file, load_config +from TTS.utils.radam import RAdam +from TTS.utils.tensorboard_logger import TensorboardLogger +from TTS.utils.training import NoamLR, setup_torch_training_env + +use_cuda, num_gpus = setup_torch_training_env(True, False) + + +def setup_loader(ap, r, is_val=False, verbose=False): + if is_val and not c.run_eval: + loader = None + else: + dataset = MyDataset( + r, + c.text_cleaner, + compute_linear_spec=False, + meta_data=meta_data_eval if is_val else meta_data_train, + ap=ap, + tp=c.characters if 'characters' in c.keys() else None, + add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * + c.batch_size, + min_seq_len=c.min_seq_len, + max_seq_len=c.max_seq_len, + phoneme_cache_path=c.phoneme_cache_path, + use_phonemes=c.use_phonemes, + phoneme_language=c.phoneme_language, + enable_eos_bos=c.enable_eos_bos_chars, + use_noise_augment=not is_val, + verbose=verbose, + speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + + if c.use_phonemes and c.compute_input_seq_cache: + # precompute phonemes to have a better estimate of sequence lengths. 
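+            # (the computed sequences are cached under `phoneme_cache_path`, so the cost
+            #  is paid once; sorting afterwards lets the min/max_seq_len filtering and
+            #  length bucketing work on the actual sequence lengths, not raw text lengths)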
+ dataset.compute_input_seq(c.num_loader_workers) + dataset.sort_items() + + sampler = DistributedSampler(dataset) if num_gpus > 1 else None + loader = DataLoader( + dataset, + batch_size=c.eval_batch_size if is_val else c.batch_size, + shuffle=False, + collate_fn=dataset.collate_fn, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers + if is_val else c.num_loader_workers, + pin_memory=False) + return loader + + +def format_data(data): + # setup input data + text_input = data[0] + text_lengths = data[1] + speaker_names = data[2] + mel_input = data[4].permute(0, 2, 1) # B x D x T + mel_lengths = data[5] + item_idx = data[7] + attn_mask = data[9] + avg_text_length = torch.mean(text_lengths.float()) + avg_spec_length = torch.mean(mel_lengths.float()) + + if c.use_speaker_embedding: + if c.use_external_speaker_embedding_file: + # return precomputed embedding vector + speaker_c = data[8] + else: + # return speaker_id to be used by an embedding layer + speaker_c = [ + speaker_mapping[speaker_name] for speaker_name in speaker_names + ] + speaker_c = torch.LongTensor(speaker_c) + else: + speaker_c = None + # compute durations from attention mask + durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) + for idx, am in enumerate(attn_mask): + # compute raw durations + c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1] + # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) + c_idxs, counts = torch.unique(c_idxs, return_counts=True) + dur = torch.ones([text_lengths[idx]]).to(counts.dtype) + dur[c_idxs] = counts + # smooth the durations and set any 0 duration to 1 + # by cutting off from the largest duration indeces. + extra_frames = dur.sum() - mel_lengths[idx] + largest_idxs = torch.argsort(-dur)[:extra_frames] + dur[largest_idxs] -= 1 + assert dur.sum() == mel_lengths[idx], f" [!] 
total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, :text_lengths[idx]] = dur + # dispatch data to GPU + if use_cuda: + text_input = text_input.cuda(non_blocking=True) + text_lengths = text_lengths.cuda(non_blocking=True) + mel_input = mel_input.cuda(non_blocking=True) + mel_lengths = mel_lengths.cuda(non_blocking=True) + if speaker_c is not None: + speaker_c = speaker_c.cuda(non_blocking=True) + attn_mask = attn_mask.cuda(non_blocking=True) + durations = durations.cuda(non_blocking=True) + return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ + avg_text_length, avg_spec_length, attn_mask, durations, item_idx + + +def train(data_loader, model, criterion, optimizer, scheduler, + ap, global_step, epoch): + + model.train() + epoch_time = 0 + keep_avg = KeepAverage() + if use_cuda: + batch_n_iter = int( + len(data_loader.dataset) / (c.batch_size * num_gpus)) + else: + batch_n_iter = int(len(data_loader.dataset) / c.batch_size) + end_time = time.time() + c_logger.print_train_start() + scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # format data + text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ + avg_text_length, avg_spec_length, attn_mask, dur_target, item_idx = format_data(data) + + loader_time = time.time() - end_time + + global_step += 1 + optimizer.zero_grad() + + # forward pass model + with torch.cuda.amp.autocast(enabled=c.mixed_precision): + decoder_output, dur_output, alignments = model.forward( + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + + # compute loss + loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + + # backward pass with loss scaling + if c.mixed_precision: + scaler.scale(loss_dict['loss']).backward() + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), + c.grad_clip) + scaler.step(optimizer) + scaler.update() + else: + loss_dict['loss'].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), + c.grad_clip) + optimizer.step() + + # setup lr + if c.noam_schedule: + scheduler.step() + + # current_lr + current_lr = optimizer.param_groups[0]['lr'] + + # compute alignment error (the lower the better ) + align_error = 1 - alignment_diagonal_score(alignments, binary=True) + loss_dict['align_error'] = align_error + + step_time = time.time() - start_time + epoch_time += step_time + + # aggregate losses from processes + if num_gpus > 1: + loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) + loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) + loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) + loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + + # detach loss values + loss_dict_new = dict() + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_new[key] = value + else: + loss_dict_new[key] = value.item() + loss_dict = loss_dict_new + + # update avg stats + update_train_values = dict() + for key, value in loss_dict.items(): + update_train_values['avg_' + key] = value + update_train_values['avg_loader_time'] = loader_time + update_train_values['avg_step_time'] = step_time + keep_avg.update_values(update_train_values) + + # print training progress + if global_step % c.print_step == 0: + log_dict = { + + "avg_spec_length": [avg_spec_length, 1], # 
value, precision + "avg_text_length": [avg_text_length, 1], + "step_time": [step_time, 4], + "loader_time": [loader_time, 2], + "current_lr": current_lr, + } + c_logger.print_train_step(batch_n_iter, num_iter, global_step, + log_dict, loss_dict, keep_avg.avg_values) + + if args.rank == 0: + # Plot Training Iter Stats + # reduce TB load + if global_step % c.tb_plot_step == 0: + iter_stats = { + "lr": current_lr, + "grad_norm": grad_norm, + "step_time": step_time + } + iter_stats.update(loss_dict) + tb_logger.tb_train_iter_stats(global_step, iter_stats) + + if global_step % c.save_step == 0: + if c.checkpoint: + # save model + save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, + model_loss=loss_dict['loss']) + + # wait all kernels to be completed + torch.cuda.synchronize() + + # Diagnostic visualizations + idx = np.random.randint(mel_targets.shape[0]) + pred_spec = decoder_output[idx].detach().data.cpu().numpy().T + gt_spec = mel_targets[idx].data.cpu().numpy().T + align_img = alignments[idx].data.cpu() + + figures = { + "prediction": plot_spectrogram(pred_spec, ap), + "ground_truth": plot_spectrogram(gt_spec, ap), + "alignment": plot_alignment(align_img), + } + + tb_logger.tb_train_figures(global_step, figures) + + # Sample audio + train_audio = ap.inv_melspectrogram(pred_spec.T) + tb_logger.tb_train_audios(global_step, + {'TrainAudio': train_audio}, + c.audio["sample_rate"]) + end_time = time.time() + + # print epoch stats + c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) + + # Plot Epoch Stats + if args.rank == 0: + epoch_stats = {"epoch_time": epoch_time} + epoch_stats.update(keep_avg.avg_values) + tb_logger.tb_train_epoch_stats(global_step, epoch_stats) + if c.tb_model_param_stats: + tb_logger.tb_model_weights(model, global_step) + return keep_avg.avg_values, global_step + + +@torch.no_grad() +def evaluate(data_loader, model, criterion, ap, global_step, epoch): + model.eval() + epoch_time = 0 + keep_avg = KeepAverage() + c_logger.print_eval_start() + if data_loader is not None: + for num_iter, data in enumerate(data_loader): + start_time = time.time() + + # format data + text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ + avg_text_length, avg_spec_length, attn_mask, dur_target, item_idx = format_data(data) + + # forward pass model + with torch.cuda.amp.autocast(enabled=c.mixed_precision): + decoder_output, dur_output, alignments = model.forward( + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + + # compute loss + loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + + # step time + step_time = time.time() - start_time + epoch_time += step_time + + # compute alignment score + align_error = 1 - alignment_diagonal_score(alignments) + loss_dict['align_error'] = align_error + + # aggregate losses from processes + if num_gpus > 1: + loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) + loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) + loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) + loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + + # detach loss values + loss_dict_new = dict() + for key, value in loss_dict.items(): + if isinstance(value, (int, float)): + loss_dict_new[key] = value + else: + loss_dict_new[key] = value.item() + loss_dict = loss_dict_new + + # update avg stats + update_train_values = dict() + for key, value in loss_dict.items(): + 
update_train_values['avg_' + key] = value + keep_avg.update_values(update_train_values) + + if c.print_eval: + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) + + if args.rank == 0: + # Diagnostic visualizations + idx = np.random.randint(mel_targets.shape[0]) + pred_spec = decoder_output[idx].detach().data.cpu().numpy().T + gt_spec = mel_targets[idx].data.cpu().numpy().T + align_img = alignments[idx].data.cpu() + + eval_figures = { + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False) + } + + # Sample audio + eval_audio = ap.inv_melspectrogram(pred_spec.T) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, + c.audio["sample_rate"]) + + # Plot Validation Stats + tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) + tb_logger.tb_eval_figures(global_step, eval_figures) + + if args.rank == 0 and epoch >= c.test_delay_epochs: + if c.test_sentences_file is None: + test_sentences = [ + "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.", + "Be a voice, not an echo.", + "I'm sorry Dave. I'm afraid I can't do that.", + "This cake is great. It's so delicious and moist.", + "Prior to November 22, 1963." + ] + else: + with open(c.test_sentences_file, "r") as f: + test_sentences = [s.strip() for s in f.readlines()] + + # test sentences + test_audios = {} + test_figures = {} + print(" | > Synthesizing test sentences") + if c.use_speaker_embedding: + if c.use_external_speaker_embedding_file: + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_id = None + else: + speaker_id = 0 + speaker_embedding = None + else: + speaker_id = None + speaker_embedding = None + + style_wav = c.get("style_wav_for_test") + for idx, test_sentence in enumerate(test_sentences): + try: + wav, alignment, _, postnet_output, _, _ = synthesis( + model, + test_sentence, + c, + use_cuda, + ap, + speaker_id=speaker_id, + speaker_embedding=speaker_embedding, + style_wav=style_wav, + truncated=False, + enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + use_griffin_lim=True, + do_trim_silence=False) + + file_path = os.path.join(AUDIO_PATH, str(global_step)) + os.makedirs(file_path, exist_ok=True) + file_path = os.path.join(file_path, + "TestSentence_{}.wav".format(idx)) + ap.save_wav(wav, file_path) + test_audios['{}-audio'.format(idx)] = wav + test_figures['{}-prediction'.format(idx)] = plot_spectrogram( + postnet_output, ap) + test_figures['{}-alignment'.format(idx)] = plot_alignment( + alignment) + except: #pylint: disable=bare-except + print(" !! Error creating Test Sentence -", idx) + traceback.print_exc() + tb_logger.tb_test_audios(global_step, test_audios, + c.audio['sample_rate']) + tb_logger.tb_test_figures(global_step, test_figures) + return keep_avg.avg_values + + +# FIXME: move args definition/parsing inside of main? 
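+
+# Reference sketch (illustrative helper, not called by the training loop): the
+# per-item duration extraction that `format_data` performs above, restated for a
+# single alignment map so the steps are easier to follow. `attn` is assumed to be
+# a [T_text, T_mel] alignment matrix and `text_len` / `mel_len` plain ints.
+def _durations_from_alignment_example(attn, text_len, mel_len):
+    # for every mel frame, take the text index it is most strongly aligned to
+    frame_to_char = attn[:text_len, :mel_len].max(0)[1]
+    # count how many frames landed on each text index
+    char_idxs, counts = torch.unique(frame_to_char, return_counts=True)
+    # start every character at duration 1 so unattended characters are not dropped
+    dur = torch.ones([text_len]).to(counts.dtype)
+    dur[char_idxs] = counts
+    # the 1-frame defaults can overshoot the spectrogram length; trim the surplus
+    # frames, one each, from the largest durations so the total matches mel_len
+    extra_frames = dur.sum() - mel_len
+    if extra_frames > 0:
+        largest_idxs = torch.argsort(-dur)[:extra_frames]
+        dur[largest_idxs] -= 1
+    assert dur.sum() == mel_len
+    return dur
+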
+def main(args): # pylint: disable=redefined-outer-name + # pylint: disable=global-variable-undefined + global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping + # Audio processor + ap = AudioProcessor(**c.audio) + if 'characters' in c.keys(): + symbols, phonemes = make_symbols(**c.characters) + + # DISTRUBUTED + if num_gpus > 1: + init_distributed(args.rank, num_gpus, args.group_id, + c.distributed["backend"], c.distributed["url"]) + num_chars = len(phonemes) if c.use_phonemes else len(symbols) + + # load data instances + meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) + + # set the portion of the data used for training if set in config.json + if 'train_portion' in c.keys(): + meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] + if 'eval_portion' in c.keys(): + meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + + # parse speakers + num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) + + # setup model + model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim) + optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) + criterion = SpeedySpeechLoss(c) + + if args.restore_path: + checkpoint = torch.load(args.restore_path, map_location='cpu') + try: + # TODO: fix optimizer init, model.cuda() needs to be called before + # optimizer restore + optimizer.load_state_dict(checkpoint['optimizer']) + if c.reinit_layers: + raise RuntimeError + model.load_state_dict(checkpoint['model']) + except: #pylint: disable=bare-except + print(" > Partial model initialization.") + model_dict = model.state_dict() + model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model.load_state_dict(model_dict) + del model_dict + + for group in optimizer.param_groups: + group['initial_lr'] = c.lr + print(" > Model restored from step %d" % checkpoint['step'], + flush=True) + args.restore_step = checkpoint['step'] + else: + args.restore_step = 0 + + if use_cuda: + model.cuda() + criterion.cuda() + + # DISTRUBUTED + if num_gpus > 1: + model = DDP_th(model, device_ids=[args.rank]) + + if c.noam_schedule: + scheduler = NoamLR(optimizer, + warmup_steps=c.warmup_steps, + last_epoch=args.restore_step - 1) + else: + scheduler = None + + num_params = count_parameters(model) + print("\n > Model has {} parameters".format(num_params), flush=True) + + if 'best_loss' not in locals(): + best_loss = float('inf') + + # define dataloaders + train_loader = setup_loader(ap, 1, is_val=False, verbose=True) + eval_loader = setup_loader(ap, 1, is_val=True, verbose=True) + + global_step = args.restore_step + for epoch in range(0, c.epochs): + c_logger.print_epoch_start(epoch, c.epochs) + train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, + scheduler, ap, global_step, + epoch) + eval_avg_loss_dict = evaluate(eval_loader , model, criterion, ap, global_step, epoch) + c_logger.print_epoch_end(epoch, eval_avg_loss_dict) + target_loss = train_avg_loss_dict['avg_loss'] + if c.run_eval: + target_loss = eval_avg_loss_dict['avg_loss'] + best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, + OUT_PATH) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--continue_path', + type=str, + help='Training output folder to continue training. Use to continue a training. 
If it is used, "config_path" is ignored.', + default='', + required='--config_path' not in sys.argv) + parser.add_argument( + '--restore_path', + type=str, + help='Model file to be restored. Use to finetune a model.', + default='') + parser.add_argument( + '--config_path', + type=str, + help='Path to config file for training.', + required='--continue_path' not in sys.argv + ) + parser.add_argument('--debug', + type=bool, + default=False, + help='Do not verify commit integrity to run training.') + + # DISTRUBUTED + parser.add_argument( + '--rank', + type=int, + default=0, + help='DISTRIBUTED: process rank for distributed training.') + parser.add_argument('--group_id', + type=str, + default="", + help='DISTRIBUTED: process group id.') + args = parser.parse_args() + + if args.continue_path != '': + args.output_path = args.continue_path + args.config_path = os.path.join(args.continue_path, 'config.json') + list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv + latest_model_file = max(list_of_files, key=os.path.getctime) + args.restore_path = latest_model_file + print(f" > Training continues for {args.restore_path}") + + # setup output paths and read configs + c = load_config(args.config_path) + # check_config(c) + check_config_tts(c) + _ = os.path.dirname(os.path.realpath(__file__)) + + if c.mixed_precision: + print(" > Mixed precision enabled.") + + OUT_PATH = args.continue_path + if args.continue_path == '': + OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) + + AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios') + + c_logger = ConsoleLogger() + + if args.rank == 0: + os.makedirs(AUDIO_PATH, exist_ok=True) + new_fields = {} + if args.restore_path: + new_fields["restore_path"] = args.restore_path + new_fields["github_branch"] = get_git_branch() + copy_config_file(args.config_path, + os.path.join(OUT_PATH, 'config.json'), new_fields) + os.chmod(AUDIO_PATH, 0o775) + os.chmod(OUT_PATH, 0o775) + + LOG_DIR = OUT_PATH + tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS') + + # write model desc to tensorboard + tb_logger.tb_add_text('model-description', c['run_description'], 0) + + try: + main(args) + except KeyboardInterrupt: + remove_experiment_folder(OUT_PATH) + try: + sys.exit(0) + except SystemExit: + os._exit(0) # pylint: disable=protected-access + except Exception: # pylint: disable=broad-except + remove_experiment_folder(OUT_PATH) + traceback.print_exc() + sys.exit(1) diff --git a/TTS/tts/configs/speedy_speech_ljspeech.json b/TTS/tts/configs/speedy_speech_ljspeech.json new file mode 100644 index 00000000..5d2fd260 --- /dev/null +++ b/TTS/tts/configs/speedy_speech_ljspeech.json @@ -0,0 +1,149 @@ +{ + "model": "speedy_speech", + "run_name": "speedy-speech-ljspeech", + "run_description": "speedy-speech model for LJSpeech dataset.", + + // AUDIO PARAMETERS + "audio":{ + // stft parameters + "fft_size": 1024, // number of stft frequency levels. Size of the linear spectogram frame. + "win_length": 1024, // stft window length in ms. + "hop_length": 256, // stft window hop-lengh in ms. + "frame_length_ms": null, // stft window length in ms.If null, 'win_length' is used. + "frame_shift_ms": null, // stft window hop-lengh in ms. If null, 'hop_length' is used. + + // Audio processing parameters + "sample_rate": 22050, // DATASET-RELATED: wav sample-rate. + "preemphasis": 0.0, // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no -pre-emphasis. 
+ "ref_level_db": 20, // reference level db, theoretically 20db is the sound of air. + + // Silence trimming + "do_trim_silence": true,// enable trimming of slience of audio as you load it. LJspeech (true), TWEB (false), Nancy (true) + "trim_db": 60, // threshold for timming silence. Set this according to your dataset. + + // Griffin-Lim + "power": 1.5, // value to sharpen wav signals after GL algorithm. + "griffin_lim_iters": 60,// #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation. + + // MelSpectrogram parameters + "num_mels": 80, // size of the mel spec frame. + "mel_fmin": 50.0, // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!! + "mel_fmax": 7600.0, // maximum freq level for mel-spec. Tune for dataset!! + "spec_gain": 1, + + // Normalization parameters + "signal_norm": true, // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params. + "min_level_db": -100, // lower bound for normalization + "symmetric_norm": true, // move normalization to range [-1, 1] + "max_norm": 4.0, // scale normalization to range [-max_norm, max_norm] or [0, max_norm] + "clip_norm": true, // clip normalized values into the range. + "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored + }, + + // VOCABULARY PARAMETERS + // if custom character set is not defined, + // default set in symbols.py is used + // "characters":{ + // "pad": "_", + // "eos": "&", + // "bos": "*", + // "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZÇÃÀÁÂÊÉÍÓÔÕÚÛabcdefghijklmnopqrstuvwxyzçãàáâêéíóôõúû!(),-.:;? ", + // "punctuations":"!'(),-.:;? ", + // "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ'̃' " + // }, + + "add_blank": false, // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model. + + // DISTRIBUTED TRAINING + "distributed":{ + "backend": "nccl", + "url": "tcp:\/\/localhost:54321" + }, + + "reinit_layers": [], // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers. + + // MODEL PARAMETERS + "positional_encoding": true, + "encoder_type": "residual_conv_bn", + "encoder_params":{ + "kernel_size": 4, + "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + }, + "decoder_residual_conv_bn_params":{ + "kernel_size": 4, + "dilations": [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + }, + + // TRAINING + "batch_size":64, // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'. + "eval_batch_size":32, + "r": 1, // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled. + "loss_masking": true, // enable / disable loss masking against the sequence padding. + + // LOSS PARAMETERS + "ssim_alpha": 1, + "l1_alpha": 1, + "huber_alpha": 1, + + // VALIDATION + "run_eval": true, + "test_delay_epochs": -1, //Until attention is aligned, testing only wastes computation time. 
+ "test_sentences_file": null, // set a file to load sentences to be used for testing. If it is null then we use default english sentences. + + // OPTIMIZER + "noam_schedule": true, // use noam warmup and lr schedule. + "grad_clip": 1.0, // upper limit for gradients for clipping. + "epochs": 10000, // total number of epochs to train. + "lr": 0.002, // Initial learning rate. If Noam decay is active, maximum learning rate. + "warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr" + + // TENSORBOARD and LOGGING + "print_step": 25, // Number of steps to log training on console. + "tb_plot_step": 100, // Number of steps to plot TB training figures. + "print_eval": false, // If True, it prints intermediate loss values in evalulation. + "save_step": 5000, // Number of training steps expected to save traninpg stats and checkpoints. + "checkpoint": true, // If true, it saves checkpoints per "save_step" + "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.:set n + "mixed_precision": false, + + // DATA LOADING + "text_cleaner": "english_cleaners", + "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. + "num_loader_workers": 8, // number of training data loader processes. Don't set it too big. 4-8 are good values. + "num_val_loader_workers": 8, // number of evaluation data loader processes. + "batch_group_size": 0, //Number of batches to shuffle after bucketing. + "min_seq_len": 2, // DATASET-RELATED: minimum text length to use in training + "max_seq_len": 300, // DATASET-RELATED: maximum text length + "compute_f0": false, // compute f0 values in data-loader + "compute_input_seq_cache": false, // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage. + + // PATHS + "output_path": "/home/erogol/Models/ljspeech/", + + // PHONEMES + "phoneme_cache_path": "/home/erogol/Models/ljspeech_phonemes/", // phoneme computation is slow, therefore, it caches results in the given folder. + "use_phonemes": true, // use phonemes instead of raw characters. It is suggested for better pronoun[ciation. + "phoneme_language": "en-us", // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages + + // MULTI-SPEAKER and GST + "use_speaker_embedding": false, // use speaker embedding to enable multi-speaker learning. + "use_external_speaker_embedding_file": false, // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 + "external_speaker_embedding_file": "/home/erogol/Data/libritts/speakers.json", // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs /1806.04558 + + + // DATASETS + "datasets": // List of datasets. 
They all merged and they get different s$ + [ + { + "name": "ljspeech", + "path": "/home/erogol/Data/LJSpeech-1.1/", + "meta_file_train": "metadata.csv", + "meta_file_val": null, + "meta_file_attn_mask": "/home/erogol/Data/LJSpeech-1.1/metadata_attn_mask.txt" // created by bin/compute_attention_masks.py + } + ] +} \ No newline at end of file diff --git a/TTS/tts/layers/speedy_speech/__init__.py b/TTS/tts/layers/speedy_speech/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/TTS/tts/layers/speedy_speech/decoder.py b/TTS/tts/layers/speedy_speech/decoder.py new file mode 100644 index 00000000..f4f23840 --- /dev/null +++ b/TTS/tts/layers/speedy_speech/decoder.py @@ -0,0 +1,41 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.glow_tts.transformer import Transformer +from TTS.tts.layers.glow_tts.glow import ConvLayerNorm +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock, ConvBNBlock + + +class Decoder(nn.Module): + """Decodes the expanded phoneme encoding into spectrograms + Shapes: + - input: (B, C, T) + """ + # pylint: disable=dangerous-default-value + def __init__( + self, + out_channels, + hidden_channels, + residual_conv_bn_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + }): + super().__init__() + + self.decoder = ResidualConvBNBlock(hidden_channels, + **residual_conv_bn_params) + + self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) + self.post_net = nn.Sequential( + ConvBNBlock(hidden_channels, 4, 1, num_conv_blocks=2), + nn.Conv1d(hidden_channels, out_channels, 1), + ) + + def forward(self, x, x_mask, g=None): + o = self.decoder(x, x_mask) + o = self.post_conv(o) + x + return self.post_net(o) \ No newline at end of file diff --git a/TTS/tts/layers/speedy_speech/duration_predictor.py b/TTS/tts/layers/speedy_speech/duration_predictor.py new file mode 100644 index 00000000..9f83d94d --- /dev/null +++ b/TTS/tts/layers/speedy_speech/duration_predictor.py @@ -0,0 +1,27 @@ +from torch import nn + +from TTS.tts.layers.generic.res_conv_bn import ConvBN + + +class DurationPredictor(nn.Module): + """Predicts phoneme log durations based on the encoder outputs""" + def __init__(self, hidden_channels): + super().__init__() + + self.layers = nn.ModuleList([ + ConvBN(hidden_channels, 4, 1), + ConvBN(hidden_channels, 3, 1), + ConvBN(hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1) + ]) + + def forward(self, x, x_mask): + """Outputs interpreted as log(durations) + To get actual durations, do exp transformation + :param x: + :return: + """ + o = x + for layer in self.layers: + o = layer(o) * x_mask + return o diff --git a/TTS/tts/layers/speedy_speech/encoder.py b/TTS/tts/layers/speedy_speech/encoder.py new file mode 100644 index 00000000..0ec222f9 --- /dev/null +++ b/TTS/tts/layers/speedy_speech/encoder.py @@ -0,0 +1,153 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from TTS.tts.layers.glow_tts.transformer import Transformer +from TTS.tts.layers.glow_tts.glow import ConvLayerNorm +from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock + + +class PositionalEncoding(nn.Module): + """Sinusoidal positional encoding for non-recurrent neural networks. 
+ Implementation based on "Attention Is All You Need" + Args: + dropout (float): dropout parameter + dim (int): embedding size + """ + + def __init__(self, dim, dropout=0.0, max_len=5000): + if dim % 2 != 0: + raise ValueError("Cannot use sin/cos positional encoding with " + "odd dim (got dim={:d})".format(dim)) + pe = torch.zeros(max_len, dim) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) * + -(math.log(10000.0) / dim))) + pe[:, 0::2] = torch.sin(position.float() * div_term) + pe[:, 1::2] = torch.cos(position.float() * div_term) + pe = pe.unsqueeze(0).transpose(1, 2) + super(PositionalEncoding, self).__init__() + self.register_buffer('pe', pe) + if dropout > 0: + self.dropout = nn.Dropout(p=dropout) + self.dim = dim + + def forward(self, x, step=None): + """Embed inputs. + Args: + x (FloatTensor): Sequence of word vectors + ``(seq_len, batch_size, self.dim)`` + step (int or NoneType): If stepwise (``seq_len = 1``), use + the encoding for this position. + + Shapes: + x: B x C x T + """ + + x = x * math.sqrt(self.dim) + if step is None: + if self.pe.size(2) < x.size(2): + raise RuntimeError( + f"Sequence is {x.size(2)} but PositionalEncoding is" + f" limited to {self.pe.size(2)}. See max_len argument." + ) + x = x + self.pe[:, : ,:x.size(2)] + else: + x = x + self.pe[:, :, step] + if hasattr(self, 'dropout'): + x = self.dropout(x) + return x + + +class Encoder(nn.Module): + # pylint: disable=dangerous-default-value + def __init__( + self, + hidden_channels, + out_channels, + encoder_type='residual_conv_bn', + encoder_params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + }, + c_in_channels=0): + """Speedy-Speech encoder using Transformers or Residual BN Convs internally. + + Args: + num_chars (int): number of characters. + out_channels (int): number of output channels. + hidden_channels (int): encoder's embedding size. + encoder_type (str): encoder layer types. 'transformers' or 'residual_conv_bn'. Default 'residual_conv_bn'. + encoder_params (dict): model parameters for specified encoder type. + c_in_channels (int): number of channels for conditional input. + + Note: + Default encoder_params... + + for 'transformer' + encoder_params={ + 'hidden_channels_ffn': 768, + 'num_heads': 2, + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "rel_attn_window_size": 4, + "input_length": None + }, + + for 'residual_conv_bn' + encoder_params = { + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + } + + Shapes: + - input: (B, C, T) + """ + super().__init__() + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.encoder_type = encoder_type + self.c_in_channels = c_in_channels + + # init encoder + if encoder_type.lower() == "transformer": + # optional convolutional prenet + self.pre = ConvLayerNorm(hidden_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_layers=3, + dropout_p=0.5) + # text encoder + self.encoder = Transformer(hidden_channels, **encoder_params) + elif encoder_type.lower() == 'residual_conv_bn': + self.pre = nn.Sequential( + nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConvBNBlock(hidden_channels, + **encoder_params) + else: + raise NotImplementedError(' [!] 
encoder type not implemented.') + + # final projection layers + self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) + self.post_bn = nn.BatchNorm1d(hidden_channels) + self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1) + + def forward(self, x, x_mask, g=None): + if self.encoder_type == 'transformer': + o = self.pre(x, x_mask) + else: + o = self.pre(x) * x_mask + o = self.encoder(o, x_mask) + o = self.post_conv(o + x) + o = self.post_bn(o) + o = F.relu(o) + o = self.post_conv2(o) + # [B, C, T] + return o * x_mask diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py new file mode 100644 index 00000000..6ce96892 --- /dev/null +++ b/TTS/tts/models/speedy_speech.py @@ -0,0 +1,123 @@ +import torch +from torch import nn +from TTS.tts.layers.speedy_speech.decoder import Decoder +from TTS.tts.layers.speedy_speech.duration_predictor import DurationPredictor +from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding +from TTS.tts.utils.generic_utils import sequence_mask +from TTS.tts.layers.glow_tts.monotonic_align import generate_path + + +class SpeedySpeech(nn.Module): + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + positional_encoding=True, + length_scale=1, + encoder_type='residual_conv_bn', + encoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 13 + }, + decoder_residual_conv_bn_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17 + }, + c_in_channels=0): + super().__init__() + self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale + self.emb = nn.Embedding(num_chars, hidden_channels) + self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, + encoder_params, c_in_channels) + if positional_encoding: + self.pos_encoder = PositionalEncoding(hidden_channels) + self.decoder = Decoder(out_channels, hidden_channels, + decoder_residual_conv_bn_params) + self.duration_predictor = DurationPredictor(hidden_channels) + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) + o_en_ex = torch.matmul( + attn.squeeze(1).transpose(1, 2), en.transpose(1, + 2)).transpose(1, 2) + return o_en_ex, attn + + def format_durations(self, o_dr_log, x_mask): + o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale + o_dr[o_dr < 1] = 1.0 + o_dr = torch.round(o_dr) + return o_dr + + def forward(self, x, x_lengths, y_lengths, dr, g=None): + """ + docstring + """ + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), + 1).to(x.dtype) + + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), + 1).to(x_mask.dtype) + + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # duration predictor pass + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) + + # positional encoding + if hasattr(self, 'pos_encoder'): + o_en_ex = self.pos_encoder(o_en_ex) + + # decoder pass + o_de = self.decoder(o_en_ex, y_mask) + + return o_de, o_dr_log.squeeze(1), attn.transpose(1, 2) + + def inference(self, x, x_lengths, g=None): + # pad input to prevent dropping the last word + x = 
torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0) + + # [B, T, C] + x_emb = self.emb(x) + # [B, C, T] + x_emb = torch.transpose(x_emb, 1, -1) + + # compute sequence masks + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), + 1).to(x.dtype) + # encoder pass + o_en = self.encoder(x_emb, x_mask) + + # duration predictor pass + o_dr_log = self.duration_predictor(o_en.detach(), x_mask) + o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + + # output mask + y_mask = torch.unsqueeze(sequence_mask(o_dr.sum(1), None), 1).to(x_mask.dtype) + + # expand o_en with durations + o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask) + + # positional encoding + if hasattr(self, 'pos_encoder'): + o_en_ex = self.pos_encoder(o_en_ex) + + # decoder pass + o_de = self.decoder(o_en_ex, y_mask) + + return o_de, attn.transpose(1, 2)
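
Usage sketch (illustrative only, not part of the patch above): a minimal smoke test of the new model with toy tensors. `num_chars=130` and `hidden_channels=128` are placeholder values, and in the training script the model is normally built through `setup_model(...)` from the config rather than instantiated directly; the shapes in the comments follow the forward()/inference() signatures shown in the diff.

    import torch

    from TTS.tts.models.speedy_speech import SpeedySpeech

    # build a model with the default residual_conv_bn encoder/decoder parameters
    model = SpeedySpeech(num_chars=130, out_channels=80, hidden_channels=128)
    model.eval()

    # toy batch: 2 sentences, up to 50 tokens, up to 200 mel frames
    x = torch.randint(1, 130, (2, 50))
    x_lengths = torch.tensor([50, 40])
    y_lengths = torch.tensor([200, 160])

    # ground-truth durations (in frames); their row sums match y_lengths
    durations = torch.zeros(2, 50)
    durations[0, :50] = 4
    durations[1, :40] = 4

    # training-style pass: the given durations expand encoder outputs to frame level
    o_de, o_dr_log, attn = model(x, x_lengths, y_lengths, durations)
    print(o_de.shape)      # torch.Size([2, 80, 200])  -> (B, num_mels, T_mel)
    print(o_dr_log.shape)  # torch.Size([2, 50])       -> predicted log-durations
    print(attn.shape)      # torch.Size([2, 200, 50])  -> frame-to-character alignment

    # inference pass: durations come from the duration predictor instead
    with torch.no_grad():
        mel, attn = model.inference(x, x_lengths)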