mirror of https://github.com/coqui-ai/TTS.git
weight norm and torch based amp training for wavegrad
parent a3213762ae
commit 14c2381207

WaveGrad training script:

@@ -7,10 +7,7 @@ import traceback
 import torch
 # DISTRIBUTED
-try:
-    from apex.parallel import DistributedDataParallel as DDP_apex
-except:
-    from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
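With the apex fallback removed, distributed runs rely solely on torch's own DistributedDataParallel. A minimal, self-contained sketch of that wrapping, assuming a single-process gloo group and a dummy module in place of the WaveGrad model (purely illustrative, not the trainer's own setup):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP_th

# Single-process demo: a 1-worker gloo group is enough for DDP to wrap a module.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = torch.nn.Linear(8, 8)        # stand-in for the real model
ddp_model = DDP_th(model)            # no apex fallback needed
print(type(ddp_model).__name__)      # DistributedDataParallel

dist.destroy_process_group()
```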
@@ -82,7 +79,7 @@ def format_test_data(data):


 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch, amp):
+          scheduler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -104,6 +101,7 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
+    scaler = torch.cuda.amp.GradScaler()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()

@@ -113,39 +111,46 @@ def train(model, criterion, optimizer,

         global_step += 1

-        # compute noisy input
-        if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-        else:
-            noise, x_noisy, noise_scale = model.compute_y_n(x)
+        with torch.cuda.amp.autocast():
+            # compute noisy input
+            if hasattr(model, 'module'):
+                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
+            else:
+                noise, x_noisy, noise_scale = model.compute_y_n(x)

-        # forward pass
-        noise_hat = model(x_noisy, m, noise_scale)
+            # forward pass
+            noise_hat = model(x_noisy, m, noise_scale)

-        # compute losses
-        loss = criterion(noise, noise_hat)
-        # if loss.item() > 100:
-        #     breakpoint()
+            # compute losses
+            loss = criterion(noise, noise_hat)
         loss_wavegrad_dict = {'wavegrad_loss':loss}

-        # backward pass with loss scaling
+        # check nan loss
+        if torch.isnan(loss).any():
+            raise RuntimeError(f'Detected NaN loss at step {self.step}.')
+
         optimizer.zero_grad()

-        if amp is not None:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss.backward()
-
-        if c.clip_grad > 0:
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
-                                                       c.clip_grad)
-        optimizer.step()
-
         # schedule update
         if scheduler is not None:
             scheduler.step()

+        # backward pass with loss scaling
+        if c.mixed_precision:
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            optimizer.step()
+
         # disconnect loss values
         loss_dict = dict()
         for key, value in loss_wavegrad_dict.items():
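The rewritten loop follows the standard torch.cuda.amp recipe: run the forward pass under autocast, backpropagate a scaled loss, unscale before gradient clipping, then step and update the scaler. A minimal, runnable sketch of that pattern with a dummy model and random data (not the trainer itself):

```python
import torch

# Minimal sketch of the torch.cuda.amp pattern used above; falls back to
# plain fp32 when CUDA is unavailable.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = torch.nn.Linear(10, 1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(3):
    x = torch.randn(8, 10, device=device)
    y = torch.randn(8, 1, device=device)
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()      # backward on the scaled loss
    scaler.unscale_(optimizer)         # unscale so clipping sees true gradient values
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)             # skipped internally if gradients overflowed
    scaler.update()
```

With `enabled=False`, both `autocast` and `GradScaler` become no-ops, so the same code path can also cover the plain fp32 case that the trainer keeps in its `else` branch.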
@@ -175,7 +180,7 @@ def train(model, criterion, optimizer,
                 'step_time': [step_time, 2],
                 'loader_time': [loader_time, 4],
                 "current_lr": current_lr,
-                "grad_norm": grad_norm
+                "grad_norm": grad_norm.item()
             }
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       log_dict, loss_dict, keep_avg.avg_values)
@@ -185,7 +190,7 @@ def train(model, criterion, optimizer,
         if global_step % 10 == 0:
             iter_stats = {
                 "lr": current_lr,
-                "grad_norm": grad_norm,
+                "grad_norm": grad_norm.item(),
                 "step_time": step_time
             }
             iter_stats.update(loss_dict)
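The `.item()` calls added to the log dictionaries reflect that `torch.nn.utils.clip_grad_norm_` returns the total gradient norm as a tensor; converting it to a plain float keeps console and TensorBoard logging tidy. A tiny demo with a dummy parameter and a hand-set gradient (illustrative values only):

```python
import torch

w = torch.nn.Parameter(torch.ones(3))
w.grad = torch.tensor([3.0, 4.0, 0.0])          # gradient with L2 norm 5
total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=1.0)
print(total_norm)         # tensor(5.) - a 0-dim tensor
print(total_norm.item())  # 5.0 - plain Python float, ready for logging
```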
@@ -335,16 +340,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

-    # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        # optimizer.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # schedulers
     scheduler = None
     if 'lr_scheduler' in c:
@@ -374,10 +369,6 @@ def main(args):  # pylint: disable=redefined-outer-name
         model.load_state_dict(model_dict)
         del model_dict

-        # DISTRUBUTED
-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         # reset lr if not countinuining training.
         for group in optimizer.param_groups:
             group['lr'] = c.lr
@@ -410,7 +401,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
                                scheduler, ap, global_step,
-                               epoch, amp)
+                               epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -426,8 +417,7 @@ def main(args):  # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict,
-                            amp_state_dict=amp.state_dict() if amp else None)
+                            model_losses=eval_avg_loss_dict)


 if __name__ == '__main__':
@@ -481,8 +471,8 @@ if __name__ == '__main__':
     _ = os.path.dirname(os.path.realpath(__file__))

     # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision is enabled")

     OUT_PATH = args.continue_path
     if args.continue_path == '':
WaveGrad training config (JSON):

@@ -34,7 +34,7 @@
     },

     // DISTRIBUTED TRAINING
-    "apex_amp_level": "O1", // APEX amp optimization level. "O1" is currently supported.
+    "mixed_precision": true, // enable torch mixed precision training (true, false)
     "distributed":{
         "backend": "nccl",
         "url": "tcp:\/\/localhost:54322"
@@ -98,7 +98,7 @@
     // TENSORBOARD and LOGGING
     "print_step": 50, // Number of steps to log traning on console.
     "print_eval": false, // If True, it prints loss values for each step in eval run.
-    "save_step": 10000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
+    "save_step": 5000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
     "checkpoint": true, // If true, it saves checkpoints per "save_step"
     "tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
WaveGrad model definition:

@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm

 from math import log as ln

@@ -13,36 +14,59 @@ class Conv1d(nn.Conv1d):
         nn.init.zeros_(self.bias)


+# class PositionalEncoding(nn.Module):
+#     def __init__(self, n_channels):
+#         super().__init__()
+#         self.n_channels = n_channels
+#         self.length = n_channels // 2
+#         assert n_channels % 2 == 0
+
+#     def forward(self, x, noise_level):
+#         """
+#         Shapes:
+#             x: B x C x T
+#             noise_level: B
+#         """
+#         return (x + self.encoding(noise_level)[:, :, None])
+
+#     def encoding(self, noise_level):
+#         step = torch.arange(
+#             self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
+#         encoding = noise_level.unsqueeze(1) * torch.exp(
+#             -ln(1e4) * step.unsqueeze(0))
+#         encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
+#         return encoding
+
+
 class PositionalEncoding(nn.Module):
-    def __init__(self, n_channels):
+    def __init__(self, n_channels, max_len=10000):
         super().__init__()
         self.n_channels = n_channels
-        self.length = n_channels // 2
-        assert n_channels % 2 == 0
+        self.max_len = max_len
+        self.C = 5000
+        self.pe = torch.zeros(0, 0)

     def forward(self, x, noise_level):
-        """
-        Shapes:
-            x: B x C x T
-            noise_level: B
-        """
-        return (x + self.encoding(noise_level)[:, :, None])
+        if x.shape[2] > self.pe.shape[1]:
+            self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
+        return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C

-    def encoding(self, noise_level):
-        step = torch.arange(
-            self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
-        encoding = noise_level.unsqueeze(1) * torch.exp(
-            -ln(1e4) * step.unsqueeze(0))
-        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
-        return encoding
+    def init_pe_matrix(self, n_channels, max_len, x):
+        pe = torch.zeros(max_len, n_channels)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
+
+        pe[:, 0::2] = torch.sin(position / div_term)
+        pe[:, 1::2] = torch.cos(position / div_term)
+        self.pe = pe.transpose(0, 1).to(x)


 class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
-        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
+        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
+        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
         self.ini_parameters()

     def ini_parameters(self):
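The rewritten PositionalEncoding swaps the noise-level-scaled encoding for a cached sinusoidal table that is simply added to the input together with the broadcast noise level. A standalone sketch of the same computation, mirroring the `init_pe_matrix` math above (the helper name and shapes here are illustrative, not part of the model):

```python
import torch

def sinusoid_table(n_channels, length):
    # Builds the same [C, T] sin/cos table as init_pe_matrix above.
    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)                 # [T, 1]
    div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)   # [C/2]
    pe = torch.zeros(length, n_channels)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    return pe.transpose(0, 1)                                                          # [C, T]

x = torch.randn(2, 128, 50)      # B x C x T feature map
noise_level = torch.rand(2)      # one sampled noise level per batch item
pe = sinusoid_table(x.shape[1], x.shape[2])
out = x + noise_level[..., None, None] + pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / 5000
print(out.shape)                 # torch.Size([2, 128, 50])
```

The constant divisor (C = 5000 in the diff) scales the table down relative to the features, and the real module rebuilds its cached table lazily whenever a longer input arrives.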
@@ -72,30 +96,30 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = Conv1d(input_size, hidden_size, 1)
+        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.block2 = nn.ModuleList([
-            Conv1d(input_size,
+            weight_norm(Conv1d(input_size,
                    hidden_size,
                    3,
                    dilation=dilation[0],
-                   padding=dilation[0]),
-            Conv1d(hidden_size,
+                   padding=dilation[0])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[1],
-                   padding=dilation[1])
+                   padding=dilation[1]))
         ])
         self.block3 = nn.ModuleList([
-            Conv1d(hidden_size,
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[2],
-                   padding=dilation[2]),
-            Conv1d(hidden_size,
+                   padding=dilation[2])),
+            weight_norm(Conv1d(hidden_size,
                    hidden_size,
                    3,
                    dilation=dilation[3],
-                   padding=dilation[3])
+                   padding=dilation[3]))
         ])

     def forward(self, x, shift, scale):
@@ -129,11 +153,11 @@ class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.conv = nn.ModuleList([
-            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
-            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
-            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
         ])

     def forward(self, x):
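Every Conv1d in FiLM, UBlock, and DBlock is now wrapped in `weight_norm`, which reparameterizes each kernel into a magnitude (`weight_g`) and a direction (`weight_v`) that are optimized separately and often stabilizes vocoder training. A small sketch of what the wrapper changes on a single layer (shapes are illustrative):

```python
import torch
from torch.nn.utils import weight_norm, remove_weight_norm

conv = weight_norm(torch.nn.Conv1d(80, 128, 3, padding=1))
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight_g', 'weight_v']

x = torch.randn(1, 80, 100)
print(conv(x).shape)  # torch.Size([1, 128, 100]); `weight` is recomputed from g and v each forward

# For inference the parameterization can be folded back into a single weight tensor.
remove_weight_norm(conv)
print(sorted(n for n, _ in conv.named_parameters()))  # ['bias', 'weight']
```

Whether this model exposes its own remove-weight-norm helper is not shown in this diff; the fold-back step above is the generic PyTorch mechanism.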