bug fixes for single speaker glow-tts, enable torch based amp. Make amp optional for wavegrad. Bug fixes for synthesis setup for glow-tts

pull/10/head
erogol 2020-10-29 15:45:50 +01:00
parent 14c2381207
commit 946a0c0fb9
7 changed files with 65 additions and 65 deletions

View File

@ -15,8 +15,6 @@ from torch.utils.data import DataLoader
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
reduce_tensor)
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
@ -28,7 +26,8 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
remove_experiment_folder, set_init_dict,
set_amp_context)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
@ -36,7 +35,6 @@ from TTS.utils.training import (NoamLR, check_update,
setup_torch_training_env)
# DISTRIBUTED
from apex.parallel import DistributedDataParallel as DDP_apex
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.distribute import init_distributed, reduce_tensor
@ -157,7 +155,7 @@ def data_depended_init(model, ap, speaker_mapping=None):
def train(model, criterion, optimizer, scheduler,
ap, global_step, epoch, amp, speaker_mapping=None):
ap, global_step, epoch, speaker_mapping=None):
data_loader = setup_loader(ap, 1, is_val=False,
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
model.train()
@ -170,6 +168,7 @@ def train(model, criterion, optimizer, scheduler,
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -180,33 +179,38 @@ def train(model, criterion, optimizer, scheduler,
loader_time = time.time() - end_time
global_step += 1
optimizer.zero_grad()
# forward pass model
with set_amp_context(c.mixed_precision):
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
optimizer.step()
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
optimizer.step()
# setup lr
if c.noam_schedule:
scheduler.step()
optimizer.zero_grad()
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
# backward pass - DISTRIBUTED
if amp is not None:
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss_dict['loss'].backward()
if amp:
amp_opt_params = amp.master_params(optimizer)
else:
amp_opt_params = None
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
optimizer.step()
# current_lr
current_lr = optimizer.param_groups[0]['lr']
@ -269,12 +273,12 @@ def train(model, criterion, optimizer, scheduler,
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
model_loss=loss_dict['loss'],
amp_state_dict=amp.state_dict() if amp else None)
model_loss=loss_dict['loss'])
# Diagnostic visualizations
# direct pass on model for spec predictions
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
target_speaker = None if speaker_ids is None else speaker_ids[:1]
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
const_spec = spec_pred[0].data.cpu().numpy()
@ -367,10 +371,11 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
if args.rank == 0:
# Diagnostic visualizations
# direct pass on model for spec predictions
target_speaker = None if speaker_ids is None else speaker_ids[:1]
if hasattr(model, 'module'):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
spec_pred = spec_pred.permute(0, 2, 1)
gt_spec = mel_input.permute(0, 2, 1)
@ -489,14 +494,6 @@ def main(args): # pylint: disable=redefined-outer-name
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = GlowTTSLoss()
if c.apex_amp_level is not None:
# pylint: disable=import-outside-toplevel
from apex import amp
model.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
if args.restore_path:
checkpoint = torch.load(args.restore_path, map_location='cpu')
try:
@ -513,9 +510,6 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict)
del model_dict
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
@ -530,10 +524,7 @@ def main(args): # pylint: disable=redefined-outer-name
# DISTRUBUTED
if num_gpus > 1:
if c.apex_amp_level is not None:
model = DDP_apex(model)
else:
model = DDP_th(model, device_ids=[args.rank])
model = DDP_th(model, device_ids=[args.rank])
if c.noam_schedule:
scheduler = NoamLR(optimizer,
@ -554,14 +545,14 @@ def main(args): # pylint: disable=redefined-outer-name
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp, speaker_mapping)
epoch, speaker_mapping)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
OUT_PATH)
if __name__ == '__main__':
@ -614,8 +605,8 @@ if __name__ == '__main__':
check_config_tts(c)
_ = os.path.dirname(os.path.realpath(__file__))
if c.apex_amp_level:
print(" > apex AMP level: ", c.apex_amp_level)
if c.mixed_precision:
print(" > Mixed precision enabled.")
OUT_PATH = args.continue_path
if args.continue_path == '':

View File

@ -16,7 +16,8 @@ from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
remove_experiment_folder, set_init_dict,
set_amp_context)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
@ -101,7 +102,7 @@ def train(model, criterion, optimizer,
model.compute_noise_level(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
scaler = torch.cuda.amp.GradScaler()
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -111,7 +112,7 @@ def train(model, criterion, optimizer,
global_step += 1
with torch.cuda.amp.autocast():
with set_amp_context(c.mixed_precision):
# compute noisy input
if hasattr(model, 'module'):
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
@ -127,7 +128,7 @@ def train(model, criterion, optimizer,
# check nan loss
if torch.isnan(loss).any():
raise RuntimeError(f'Detected NaN loss at step {self.step}.')
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
optimizer.zero_grad()

View File

@ -102,7 +102,7 @@ class Encoder(nn.Module):
o = layer(o)
o = o.transpose(1, 2)
o = nn.utils.rnn.pack_padded_sequence(o,
input_lengths,
input_lengths.cpu(),
batch_first=True)
self.lstm.flatten_parameters()
o, _ = self.lstm(o)

View File

@ -248,7 +248,7 @@ def check_config_tts(c):
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
if c['use_gst']:
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)

View File

@ -210,7 +210,7 @@ def synthesis(model,
"""
# GST processing
style_mel = None
if CONFIG.use_gst and style_wav is not None:
if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:

View File

@ -1,8 +1,19 @@
import os
import glob
import shutil
import datetime
import glob
import os
import shutil
import subprocess
from contextlib import nullcontext
import torch
def set_amp_context(mixed_precision):
if mixed_precision:
cm = torch.cuda.amp.autocast()
else:
cm = nullcontext()
return cm
def get_git_branch():

View File

@ -1,11 +1,8 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from math import log as ln
class Conv1d(nn.Conv1d):
def __init__(self, *args, **kwargs):