mirror of https://github.com/coqui-ai/TTS.git

weight norm and torch based amp training for wavegrad

parent a3213762ae
commit 14c2381207
@@ -7,10 +7,7 @@ import traceback
 import torch
 # DISTRIBUTED
-try:
-    from apex.parallel import DistributedDataParallel as DDP_apex
-except:
-    from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

@@ -82,7 +79,7 @@ def format_test_data(data):


 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch, amp):
+          scheduler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0

@@ -104,6 +101,7 @@ def train(model, criterion, optimizer,
     model.compute_noise_level(noise_schedule['num_steps'],
                               noise_schedule['min_val'],
                               noise_schedule['max_val'])
+    scaler = torch.cuda.amp.GradScaler()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

@@ -113,39 +111,46 @@ def train(model, criterion, optimizer,

         global_step += 1

-        # compute noisy input
-        if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-        else:
-            noise, x_noisy, noise_scale = model.compute_y_n(x)
+        with torch.cuda.amp.autocast():
+            # compute noisy input
+            if hasattr(model, 'module'):
+                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
+            else:
+                noise, x_noisy, noise_scale = model.compute_y_n(x)

-        # forward pass
-        noise_hat = model(x_noisy, m, noise_scale)
+            # forward pass
+            noise_hat = model(x_noisy, m, noise_scale)

-        # compute losses
-        loss = criterion(noise, noise_hat)
-        # if loss.item() > 100:
-        #     breakpoint()
+            # compute losses
+            loss = criterion(noise, noise_hat)
         loss_wavegrad_dict = {'wavegrad_loss':loss}

-        # backward pass with loss scaling
         # check nan loss
         if torch.isnan(loss).any():
             raise RuntimeError(f'Detected NaN loss at step {self.step}.')

         optimizer.zero_grad()

-        if amp is not None:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss.backward()
-
-        if c.clip_grad > 0:
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
-                                                       c.clip_grad)
-        optimizer.step()
-
-        # schedule update
+        # schedule update
         if scheduler is not None:
             scheduler.step()

+        # backward pass with loss scaling
+        if c.mixed_precision:
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            optimizer.step()

         # disconnect loss values
         loss_dict = dict()
         for key, value in loss_wavegrad_dict.items():

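For reference, the torch-native AMP recipe this hunk adopts looks roughly like the following when stripped to its essentials. The toy model, loss, and tensor shapes below are assumptions, not the project's trainer; only the GradScaler/autocast mechanics mirror the hunk above.

import torch
import torch.nn as nn

# Minimal sketch of the autocast + GradScaler pattern (toy model and shapes
# are placeholders; requires a CUDA device).
model = nn.Linear(80, 80).cuda()
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()              # created once per run

for _ in range(10):
    x = torch.randn(8, 80, device='cuda')
    y = torch.randn(8, 80, device='cuda')
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():               # forward pass in mixed precision
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()                 # backward on the scaled loss
    scaler.unscale_(optimizer)                    # unscale so clipping sees real grads
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)                        # skipped internally on inf/nan grads
    scaler.update()                               # adapts the loss scale

Note that unscale_() has to come before clip_grad_norm_, otherwise the clip threshold would be compared against scaled gradients; the hunk above follows the same order.
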
@@ -175,7 +180,7 @@ def train(model, criterion, optimizer,
                 'step_time': [step_time, 2],
                 'loader_time': [loader_time, 4],
                 "current_lr": current_lr,
-                "grad_norm": grad_norm
+                "grad_norm": grad_norm.item()
             }
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       log_dict, loss_dict, keep_avg.avg_values)

@@ -185,7 +190,7 @@ def train(model, criterion, optimizer,
         if global_step % 10 == 0:
             iter_stats = {
                 "lr": current_lr,
-                "grad_norm": grad_norm,
+                "grad_norm": grad_norm.item(),
                 "step_time": step_time
             }
             iter_stats.update(loss_dict)

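The .item() change in these two hunks exists because clip_grad_norm_ returns the total norm as a zero-dimensional tensor; converting it to a Python float keeps the console and tensorboard dictionaries free of tensors. A minimal illustration, reusing the script's names:

grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad)
iter_stats = {"grad_norm": grad_norm.item()}   # plain float, safe to log/serialize
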
@@ -335,16 +340,6 @@ def main(args): # pylint: disable=redefined-outer-name
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

-    # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        # optimizer.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # schedulers
     scheduler = None
     if 'lr_scheduler' in c:

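With the scaler created inside train(), there is nothing left to initialize globally here, and the guarded apex import at the top of the file is gone too; any multi-GPU run now goes through torch's own DistributedDataParallel. A hedged sketch of that path (the wrapper call and rank handling below are assumptions, not this script's exact code):

from torch.nn.parallel import DistributedDataParallel as DDP_th

model.cuda()
if use_ddp:                               # `use_ddp` / `local_rank` are assumptions
    model = DDP_th(model, device_ids=[local_rank])
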
@@ -374,10 +369,6 @@ def main(args): # pylint: disable=redefined-outer-name
         model.load_state_dict(model_dict)
         del model_dict

-        # DISTRUBUTED
-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         # reset lr if not countinuining training.
         for group in optimizer.param_groups:
             group['lr'] = c.lr

@@ -410,7 +401,7 @@ def main(args): # pylint: disable=redefined-outer-name
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
                                scheduler, ap, global_step,
-                               epoch, amp)
+                               epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)

@@ -426,8 +417,7 @@ def main(args): # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict,
-                            amp_state_dict=amp.state_dict() if amp else None)
+                            model_losses=eval_avg_loss_dict)


 if __name__ == '__main__':

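The apex amp state no longer needs to travel with the checkpoint. If one wanted mixed-precision runs to resume with the exact same dynamic loss scale, the torch scaler's state could be saved explicitly; the commit does not do this, so the following is only a hypothetical sketch:

# Hypothetical extension, not part of this commit: persist the GradScaler state
# so the dynamic loss scale survives a restart of mixed-precision training.
torch.save({
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scaler': scaler.state_dict(),        # torch.cuda.amp.GradScaler state
    'step': global_step,
}, checkpoint_path)                       # `checkpoint_path` is an assumption
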
@@ -481,8 +471,8 @@ if __name__ == '__main__':
     _ = os.path.dirname(os.path.realpath(__file__))

     # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision is enabled")

     OUT_PATH = args.continue_path
     if args.continue_path == '':

@@ -34,7 +34,7 @@
     },

     // DISTRIBUTED TRAINING
-    "apex_amp_level": "O1", // APEX amp optimization level. "O1" is currently supported.
+    "mixed_precision": true, // enable torch mixed precision training (true, false)
     "distributed":{
         "backend": "nccl",
         "url": "tcp:\/\/localhost:54322"

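The JSON knob changes from an apex opt-level string to a plain boolean. For reference, GradScaler and autocast both accept an enabled= flag, so the same switch could also be expressed without an explicit if/else; this is a sketch assuming c is the loaded config, not how the trainer above actually branches:

use_amp = bool(c.mixed_precision)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)   # no-op scaler when disabled
with torch.cuda.amp.autocast(enabled=use_amp):        # full precision when disabled
    noise_hat = model(x_noisy, m, noise_scale)
    loss = criterion(noise, noise_hat)
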
@@ -98,7 +98,7 @@
     // TENSORBOARD and LOGGING
     "print_step": 50, // Number of steps to log traning on console.
     "print_eval": false, // If True, it prints loss values for each step in eval run.
-    "save_step": 10000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
+    "save_step": 5000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm

 from math import log as ln

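weight_norm reparameterizes a wrapped module's weight as w = g * v / ||v||, replacing the weight parameter with a magnitude weight_g and a direction weight_v; it changes the optimization dynamics but not what the layer computes. A small self-contained check of that behavior:

import torch.nn as nn
from torch.nn.utils import weight_norm

conv = weight_norm(nn.Conv1d(80, 80, 3, padding=1))
print(sorted(name for name, _ in conv.named_parameters()))
# ['bias', 'weight_g', 'weight_v'] -- 'weight' is recomputed on every forward
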
@@ -13,36 +14,59 @@ class Conv1d(nn.Conv1d):
         nn.init.zeros_(self.bias)


+# class PositionalEncoding(nn.Module):
+#     def __init__(self, n_channels):
+#         super().__init__()
+#         self.n_channels = n_channels
+#         self.length = n_channels // 2
+#         assert n_channels % 2 == 0
+
+#     def forward(self, x, noise_level):
+#         """
+#         Shapes:
+#             x: B x C x T
+#             noise_level: B
+#         """
+#         return (x + self.encoding(noise_level)[:, :, None])
+
+#     def encoding(self, noise_level):
+#         step = torch.arange(
+#             self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
+#         encoding = noise_level.unsqueeze(1) * torch.exp(
+#             -ln(1e4) * step.unsqueeze(0))
+#         encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
+#         return encoding
+
+
 class PositionalEncoding(nn.Module):
-    def __init__(self, n_channels):
+    def __init__(self, n_channels, max_len=10000):
         super().__init__()
         self.n_channels = n_channels
-        self.length = n_channels // 2
-        assert n_channels % 2 == 0
+        self.max_len = max_len
+        self.C = 5000
+        self.pe = torch.zeros(0, 0)

     def forward(self, x, noise_level):
         """
         Shapes:
             x: B x C x T
             noise_level: B
         """
-        return (x + self.encoding(noise_level)[:, :, None])
+        if x.shape[2] > self.pe.shape[1]:
+            self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
+        return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C

-    def encoding(self, noise_level):
-        step = torch.arange(
-            self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
-        encoding = noise_level.unsqueeze(1) * torch.exp(
-            -ln(1e4) * step.unsqueeze(0))
-        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
-        return encoding
+    def init_pe_matrix(self, n_channels, max_len, x):
+        pe = torch.zeros(max_len, n_channels)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
+
+        pe[:, 0::2] = torch.sin(position / div_term)
+        pe[:, 1::2] = torch.cos(position / div_term)
+        self.pe = pe.transpose(0, 1).to(x)


 class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
-        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
+        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
+        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
         self.ini_parameters()

     def ini_parameters(self):

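The rewritten PositionalEncoding caches a standard sinusoidal table, lazily rebuilt whenever a longer input arrives, adds the per-sample noise level by broadcasting, and scales the table down by the constant C. The arithmetic reproduced standalone, with assumed batch, channel, and length values:

import torch

def sinusoidal_table(n_channels, max_len):
    # same construction as init_pe_matrix above, returned as C x T
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
    pe = torch.zeros(max_len, n_channels)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    return pe.transpose(0, 1)

x = torch.randn(2, 128, 300)                 # B x C x T (shapes are assumptions)
noise_level = torch.rand(2)                  # one diffusion noise level per item
pe = sinusoidal_table(128, 300)
out = x + noise_level[..., None, None] + pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / 5000
print(out.shape)                             # torch.Size([2, 128, 300])

Note the noise level now enters as a broadcast additive offset, rather than modulating a noise-level-dependent encoding as the commented-out class did.
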
@@ -72,30 +96,30 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = Conv1d(input_size, hidden_size, 1)
+        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.block2 = nn.ModuleList([
-            Conv1d(input_size,
+            weight_norm(Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0]),
-            Conv1d(hidden_size,
+                   padding=dilation[0])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1])
+                   padding=dilation[1]))
         ])
         self.block3 = nn.ModuleList([
-            Conv1d(hidden_size,
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2]),
-            Conv1d(hidden_size,
+                   padding=dilation[2])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3])
+                   padding=dilation[3]))
         ])

     def forward(self, x, shift, scale):

@@ -129,11 +153,11 @@ class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.conv = nn.ModuleList([
-            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
-            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
-            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
         ])

     def forward(self, x):

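Since every Conv1d in FiLM, UBlock, and DBlock is now wrapped, an inference or export path would typically strip the reparameterization again. This commit does not add such a helper, so the following traversal is only a hypothetical sketch:

import torch.nn as nn
from torch.nn.utils import remove_weight_norm

def remove_all_weight_norm(module: nn.Module):
    for m in module.modules():
        try:
            remove_weight_norm(m)      # folds weight_g/weight_v back into weight
        except ValueError:
            pass                       # module was not weight-normalized
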