mirror of https://github.com/coqui-ai/TTS.git
bug fixes for single speaker glow-tts, enable torch based amp. Make amp optional for wavegrad. Bug fixes for synthesis setup for glow-tts
parent
14c2381207
commit
946a0c0fb9
|
@ -15,8 +15,6 @@ from torch.utils.data import DataLoader
|
|||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.utils.distribute import (DistributedSampler, init_distributed,
|
||||
reduce_tensor)
|
||||
from TTS.tts.utils.generic_utils import setup_model, check_config_tts
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
|
@ -28,7 +26,8 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
remove_experiment_folder, set_init_dict,
|
||||
set_amp_context)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
@ -36,7 +35,6 @@ from TTS.utils.training import (NoamLR, check_update,
|
|||
setup_torch_training_env)
|
||||
|
||||
# DISTRIBUTED
|
||||
from apex.parallel import DistributedDataParallel as DDP_apex
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||
|
@ -157,7 +155,7 @@ def data_depended_init(model, ap, speaker_mapping=None):
|
|||
|
||||
|
||||
def train(model, criterion, optimizer, scheduler,
|
||||
ap, global_step, epoch, amp, speaker_mapping=None):
|
||||
ap, global_step, epoch, speaker_mapping=None):
|
||||
data_loader = setup_loader(ap, 1, is_val=False,
|
||||
verbose=(epoch == 0), speaker_mapping=speaker_mapping)
|
||||
model.train()
|
||||
|
@ -170,6 +168,7 @@ def train(model, criterion, optimizer, scheduler,
|
|||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
c_logger.print_train_start()
|
||||
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -180,33 +179,38 @@ def train(model, criterion, optimizer, scheduler,
|
|||
loader_time = time.time() - end_time
|
||||
|
||||
global_step += 1
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
with set_amp_context(c.mixed_precision):
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
o_dur_log, o_total_dur, text_lengths)
|
||||
|
||||
# backward pass with loss scaling
|
||||
if c.mixed_precision:
|
||||
scaler.scale(loss_dict['loss']).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||
c.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss_dict['loss'].backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||
c.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
|
||||
optimizer.step()
|
||||
|
||||
# setup lr
|
||||
if c.noam_schedule:
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# forward pass model
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
o_dur_log, o_total_dur, text_lengths)
|
||||
|
||||
# backward pass - DISTRIBUTED
|
||||
if amp is not None:
|
||||
with amp.scale_loss(loss_dict['loss'], optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
loss_dict['loss'].backward()
|
||||
|
||||
if amp:
|
||||
amp_opt_params = amp.master_params(optimizer)
|
||||
else:
|
||||
amp_opt_params = None
|
||||
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True, amp_opt_params=amp_opt_params)
|
||||
optimizer.step()
|
||||
|
||||
# current_lr
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
|
@ -269,12 +273,12 @@ def train(model, criterion, optimizer, scheduler,
|
|||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||
model_loss=loss_dict['loss'],
|
||||
amp_state_dict=amp.state_dict() if amp else None)
|
||||
model_loss=loss_dict['loss'])
|
||||
|
||||
# Diagnostic visualizations
|
||||
# direct pass on model for spec predictions
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
||||
target_speaker = None if speaker_ids is None else speaker_ids[:1]
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||
spec_pred = spec_pred.permute(0, 2, 1)
|
||||
gt_spec = mel_input.permute(0, 2, 1)
|
||||
const_spec = spec_pred[0].data.cpu().numpy()
|
||||
|
@ -367,10 +371,11 @@ def evaluate(model, criterion, ap, global_step, epoch, speaker_mapping):
|
|||
if args.rank == 0:
|
||||
# Diagnostic visualizations
|
||||
# direct pass on model for spec predictions
|
||||
target_speaker = None if speaker_ids is None else speaker_ids[:1]
|
||||
if hasattr(model, 'module'):
|
||||
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
||||
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||
else:
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=speaker_ids[:1])
|
||||
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
|
||||
spec_pred = spec_pred.permute(0, 2, 1)
|
||||
gt_spec = mel_input.permute(0, 2, 1)
|
||||
|
||||
|
@ -489,14 +494,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
|
||||
criterion = GlowTTSLoss()
|
||||
|
||||
if c.apex_amp_level is not None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from apex import amp
|
||||
model.cuda()
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
|
||||
else:
|
||||
amp = None
|
||||
|
||||
if args.restore_path:
|
||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||
try:
|
||||
|
@ -513,9 +510,6 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
if amp and 'amp' in checkpoint:
|
||||
amp.load_state_dict(checkpoint['amp'])
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group['initial_lr'] = c.lr
|
||||
print(" > Model restored from step %d" % checkpoint['step'],
|
||||
|
@ -530,10 +524,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
# DISTRUBUTED
|
||||
if num_gpus > 1:
|
||||
if c.apex_amp_level is not None:
|
||||
model = DDP_apex(model)
|
||||
else:
|
||||
model = DDP_th(model, device_ids=[args.rank])
|
||||
model = DDP_th(model, device_ids=[args.rank])
|
||||
|
||||
if c.noam_schedule:
|
||||
scheduler = NoamLR(optimizer,
|
||||
|
@ -554,14 +545,14 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
train_avg_loss_dict, global_step = train(model, criterion, optimizer,
|
||||
scheduler, ap, global_step,
|
||||
epoch, amp, speaker_mapping)
|
||||
epoch, speaker_mapping)
|
||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch, speaker_mapping=speaker_mapping)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = train_avg_loss_dict['avg_loss']
|
||||
if c.run_eval:
|
||||
target_loss = eval_avg_loss_dict['avg_loss']
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||
OUT_PATH, amp_state_dict=amp.state_dict() if amp else None)
|
||||
OUT_PATH)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -614,8 +605,8 @@ if __name__ == '__main__':
|
|||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.apex_amp_level:
|
||||
print(" > apex AMP level: ", c.apex_amp_level)
|
||||
if c.mixed_precision:
|
||||
print(" > Mixed precision enabled.")
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
|
|
|
@ -16,7 +16,8 @@ from TTS.utils.console_logger import ConsoleLogger
|
|||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
remove_experiment_folder, set_init_dict,
|
||||
set_amp_context)
|
||||
from TTS.utils.io import copy_config_file, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
|
@ -101,7 +102,7 @@ def train(model, criterion, optimizer,
|
|||
model.compute_noise_level(noise_schedule['num_steps'],
|
||||
noise_schedule['min_val'],
|
||||
noise_schedule['max_val'])
|
||||
scaler = torch.cuda.amp.GradScaler()
|
||||
scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
|
||||
for num_iter, data in enumerate(data_loader):
|
||||
start_time = time.time()
|
||||
|
||||
|
@ -111,7 +112,7 @@ def train(model, criterion, optimizer,
|
|||
|
||||
global_step += 1
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
with set_amp_context(c.mixed_precision):
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||
|
@ -127,7 +128,7 @@ def train(model, criterion, optimizer,
|
|||
|
||||
# check nan loss
|
||||
if torch.isnan(loss).any():
|
||||
raise RuntimeError(f'Detected NaN loss at step {self.step}.')
|
||||
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
|
|
@ -102,7 +102,7 @@ class Encoder(nn.Module):
|
|||
o = layer(o)
|
||||
o = o.transpose(1, 2)
|
||||
o = nn.utils.rnn.pack_padded_sequence(o,
|
||||
input_lengths,
|
||||
input_lengths.cpu(),
|
||||
batch_first=True)
|
||||
self.lstm.flatten_parameters()
|
||||
o, _ = self.lstm(o)
|
||||
|
|
|
@ -248,7 +248,7 @@ def check_config_tts(c):
|
|||
check_argument('use_external_speaker_embedding_file', c, restricted=True if c['use_speaker_embedding'] else False, val_type=bool)
|
||||
check_argument('external_speaker_embedding_file', c, restricted=True if c['use_external_speaker_embedding_file'] else False, val_type=str)
|
||||
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
|
||||
if c['use_gst']:
|
||||
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
|
||||
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
|
||||
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
|
||||
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
|
||||
|
|
|
@ -210,7 +210,7 @@ def synthesis(model,
|
|||
"""
|
||||
# GST processing
|
||||
style_mel = None
|
||||
if CONFIG.use_gst and style_wav is not None:
|
||||
if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
|
||||
if isinstance(style_wav, dict):
|
||||
style_mel = style_wav
|
||||
else:
|
||||
|
|
|
@ -1,8 +1,19 @@
|
|||
import os
|
||||
import glob
|
||||
import shutil
|
||||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def set_amp_context(mixed_precision):
|
||||
if mixed_precision:
|
||||
cm = torch.cuda.amp.autocast()
|
||||
else:
|
||||
cm = nullcontext()
|
||||
return cm
|
||||
|
||||
|
||||
def get_git_branch():
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from math import log as ln
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue