From 946a0c0fb9f9b8543077dff38946e36ed867365a Mon Sep 17 00:00:00 2001 From: erogol Date: Thu, 29 Oct 2020 15:45:50 +0100 Subject: [PATCH] bug fixes for single speaker glow-tts, enable torch based amp. Make amp optional for wavegrad. Bug fixes for synthesis setup for glow-tts --- TTS/bin/train_glow_tts.py | 95 +++++++++++++++------------------- TTS/bin/train_wavegrad.py | 9 ++-- TTS/tts/layers/tacotron2.py | 2 +- TTS/tts/utils/generic_utils.py | 2 +- TTS/tts/utils/synthesis.py | 2 +- TTS/utils/generic_utils.py | 17 ++++-- TTS/vocoder/layers/wavegrad.py | 3 -- 7 files changed, 65 insertions(+), 65 deletions(-) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index e30ddc59..9358deb2 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -15,8 +15,6 @@ from torch.utils.data import DataLoader from TTS.tts.datasets.preprocess import load_meta_data from TTS.tts.datasets.TTSDataset import MyDataset from TTS.tts.layers.losses import GlowTTSLoss -from TTS.tts.utils.distribute import (DistributedSampler, init_distributed, - reduce_tensor) from TTS.tts.utils.generic_utils import setup_model, check_config_tts from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.measures import alignment_diagonal_score @@ -28,7 +26,8 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.console_logger import ConsoleLogger from TTS.utils.generic_utils import (KeepAverage, count_parameters, create_experiment_folder, get_git_branch, - remove_experiment_folder, set_init_dict) + remove_experiment_folder, set_init_dict, + set_amp_context) from TTS.utils.io import copy_config_file, load_config from TTS.utils.radam import RAdam from TTS.utils.tensorboard_logger import TensorboardLogger @@ -36,7 +35,6 @@ from TTS.utils.training import (NoamLR, check_update, setup_torch_training_env) # DISTRIBUTED -from apex.parallel import DistributedDataParallel as DDP_apex from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data.distributed import DistributedSampler from TTS.utils.distribute import init_distributed, reduce_tensor @@ -157,7 +155,7 @@ def data_depended_init(model, ap, speaker_mapping=None): def train(model, criterion, optimizer, scheduler, - ap, global_step, epoch, amp, speaker_mapping=None): + ap, global_step, epoch, speaker_mapping=None): data_loader = setup_loader(ap, 1, is_val=False, verbose=(epoch == 0), speaker_mapping=speaker_mapping) model.train() @@ -170,6 +168,7 @@ def train(model, criterion, optimizer, scheduler, batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() c_logger.print_train_start() + scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -180,33 +179,38 @@ def train(model, criterion, optimizer, scheduler, loader_time = time.time() - end_time global_step += 1 + optimizer.zero_grad() + + # forward pass model + with set_amp_context(c.mixed_precision): + z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) + + # compute loss + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, + o_dur_log, o_total_dur, text_lengths) + + # backward pass with loss scaling + if c.mixed_precision: + scaler.scale(loss_dict['loss']).backward() + scaler.unscale_(optimizer) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), + c.grad_clip) + scaler.step(optimizer) + scaler.update() + else: + loss_dict['loss'].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), + c.grad_clip) + optimizer.step() + + + grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True) + optimizer.step() # setup lr if c.noam_schedule: scheduler.step() - optimizer.zero_grad() - - # forward pass model - z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids) - - # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) - - # backward pass - DISTRIBUTED - if amp is not None: - with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss: - scaled_loss.backward() - else: - loss_dict['loss'].backward() - - if amp: - amp_opt_params = amp.master_params(optimizer) - else: - amp_opt_params = None - grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params) - optimizer.step() # current_lr current_lr = optimizer.param_groups[0]['lr'] @@ -269,12 +273,12 @@ def train(model, criterion, optimizer, scheduler, if c.checkpoint: # save model save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, - model_loss=loss_dict['loss'], - amp_state_dict=amp.state_dict() if amp else None) + model_loss=loss_dict['loss']) # Diagnostic visualizations # direct pass on model for spec predictions - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) + target_speaker = None if speaker_ids is None else speaker_ids[:1] + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) const_spec = spec_pred[0].data.cpu().numpy() @@ -367,10 +371,11 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping): if args.rank == 0: # Diagnostic visualizations # direct pass on model for spec predictions + target_speaker = None if speaker_ids is None else speaker_ids[:1] if hasattr(model, 'module'): - spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) + spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) else: - spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1]) + spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) spec_pred = spec_pred.permute(0, 2, 1) gt_spec = mel_input.permute(0, 2, 1) @@ -489,14 +494,6 @@ def main(args): # pylint: disable=redefined-outer-name optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) criterion = GlowTTSLoss() - if c.apex_amp_level is not None: - # pylint: disable=import-outside-toplevel - from apex import amp - model.cuda() - model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level) - else: - amp = None - if args.restore_path: checkpoint = torch.load(args.restore_path, map_location='cpu') try: @@ -513,9 +510,6 @@ def main(args): # pylint: disable=redefined-outer-name model.load_state_dict(model_dict) del model_dict - if amp and 'amp' in checkpoint: - amp.load_state_dict(checkpoint['amp']) - for group in optimizer.param_groups: group['initial_lr'] = c.lr print(" > Model restored from step %d" % checkpoint['step'], @@ -530,10 +524,7 @@ def main(args): # pylint: disable=redefined-outer-name # DISTRUBUTED if num_gpus > 1: - if c.apex_amp_level is not None: - model = DDP_apex(model) - else: - model = DDP_th(model, device_ids=[args.rank]) + model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: scheduler = NoamLR(optimizer, @@ -554,14 +545,14 @@ def main(args): # pylint: disable=redefined-outer-name c_logger.print_epoch_start(epoch, c.epochs) train_avg_loss_dict, global_step = train(model, criterion, optimizer, scheduler, ap, global_step, - epoch, amp, speaker_mapping) + epoch, speaker_mapping) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = train_avg_loss_dict['avg_loss'] if c.run_eval: target_loss = eval_avg_loss_dict['avg_loss'] best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r, - OUT_PATH, amp_state_dict=amp.state_dict() if amp else None) + OUT_PATH) if __name__ == '__main__': @@ -614,8 +605,8 @@ if __name__ == '__main__': check_config_tts(c) _ = os.path.dirname(os.path.realpath(__file__)) - if c.apex_amp_level: - print(" > apex AMP level: ", c.apex_amp_level) + if c.mixed_precision: + print(" > Mixed precision enabled.") OUT_PATH = args.continue_path if args.continue_path == '': diff --git a/TTS/bin/train_wavegrad.py b/TTS/bin/train_wavegrad.py index 83e5d78b..13434979 100644 --- a/TTS/bin/train_wavegrad.py +++ b/TTS/bin/train_wavegrad.py @@ -16,7 +16,8 @@ from TTS.utils.console_logger import ConsoleLogger from TTS.utils.distribute import init_distributed from TTS.utils.generic_utils import (KeepAverage, count_parameters, create_experiment_folder, get_git_branch, - remove_experiment_folder, set_init_dict) + remove_experiment_folder, set_init_dict, + set_amp_context) from TTS.utils.io import copy_config_file, load_config from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import setup_torch_training_env @@ -101,7 +102,7 @@ def train(model, criterion, optimizer, model.compute_noise_level(noise_schedule['num_steps'], noise_schedule['min_val'], noise_schedule['max_val']) - scaler = torch.cuda.amp.GradScaler() + scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None for num_iter, data in enumerate(data_loader): start_time = time.time() @@ -111,7 +112,7 @@ def train(model, criterion, optimizer, global_step += 1 - with torch.cuda.amp.autocast(): + with set_amp_context(c.mixed_precision): # compute noisy input if hasattr(model, 'module'): noise, x_noisy, noise_scale = model.module.compute_y_n(x) @@ -127,7 +128,7 @@ def train(model, criterion, optimizer, # check nan loss if torch.isnan(loss).any(): - raise RuntimeError(f'Detected NaN loss at step {self.step}.') + raise RuntimeError(f'Detected NaN loss at step {global_step}.') optimizer.zero_grad() diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron2.py index 490f3728..a02db784 100644 --- a/TTS/tts/layers/tacotron2.py +++ b/TTS/tts/layers/tacotron2.py @@ -102,7 +102,7 @@ class Encoder(nn.Module): o = layer(o) o = o.transpose(1, 2) o = nn.utils.rnn.pack_padded_sequence(o, - input_lengths, + input_lengths.cpu(), batch_first=True) self.lstm.flatten_parameters() o, _ = self.lstm(o) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 2361fa85..d43edcbf 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -248,7 +248,7 @@ def check_config_tts(c): check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool) check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str) check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) - if c['use_gst']: + if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index 3d2dd13c..cad1d21f 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -210,7 +210,7 @@ def synthesis(model, """ # GST processing style_mel = None - if CONFIG.use_gst and style_wav is not None: + if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None: if isinstance(style_wav, dict): style_mel = style_wav else: diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index dcfbbdc3..686a3453 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -1,8 +1,19 @@ -import os -import glob -import shutil import datetime +import glob +import os +import shutil import subprocess +from contextlib import nullcontext + +import torch + + +def set_amp_context(mixed_precision): + if mixed_precision: + cm = torch.cuda.amp.autocast() + else: + cm = nullcontext() + return cm def get_git_branch(): diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py index a72f2837..c6c20eb5 100644 --- a/TTS/vocoder/layers/wavegrad.py +++ b/TTS/vocoder/layers/wavegrad.py @@ -1,11 +1,8 @@ -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import weight_norm -from math import log as ln - class Conv1d(nn.Conv1d): def __init__(self, *args, **kwargs):