weight norm and torch based amp training for wavegrad

pull/10/head
erogol 2020-10-27 12:06:57 +01:00
parent a3213762ae
commit 14c2381207
3 changed files with 97 additions and 83 deletions

View File

@ -7,10 +7,7 @@ import traceback
import torch
# DISTRIBUTED
try:
from apex.parallel import DistributedDataParallel as DDP_apex
except:
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
@ -82,7 +79,7 @@ def format_test_data(data):
def train(model, criterion, optimizer,
scheduler, ap, global_step, epoch, amp):
scheduler, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train()
epoch_time = 0
@ -104,6 +101,7 @@ def train(model, criterion, optimizer,
model.compute_noise_level(noise_schedule['num_steps'],
noise_schedule['min_val'],
noise_schedule['max_val'])
scaler = torch.cuda.amp.GradScaler()
for num_iter, data in enumerate(data_loader):
start_time = time.time()
@ -113,39 +111,46 @@ def train(model, criterion, optimizer,
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
else:
noise, x_noisy, noise_scale = model.compute_y_n(x)
with torch.cuda.amp.autocast():
# compute noisy input
if hasattr(model, 'module'):
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
else:
noise, x_noisy, noise_scale = model.compute_y_n(x)
# forward pass
noise_hat = model(x_noisy, m, noise_scale)
# forward pass
noise_hat = model(x_noisy, m, noise_scale)
# compute losses
loss = criterion(noise, noise_hat)
# if loss.item() > 100:
# breakpoint()
# compute losses
loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss':loss}
# backward pass with loss scaling
# check nan loss
if torch.isnan(loss).any():
raise RuntimeError(f'Detected NaN loss at step {self.step}.')
optimizer.zero_grad()
if amp is not None:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
if c.clip_grad > 0:
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
optimizer.step()
# schedule update
# schedule update
if scheduler is not None:
scheduler.step()
# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
optimizer.step()
# disconnect loss values
loss_dict = dict()
for key, value in loss_wavegrad_dict.items():
@ -175,7 +180,7 @@ def train(model, criterion, optimizer,
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"current_lr": current_lr,
"grad_norm": grad_norm
"grad_norm": grad_norm.item()
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
@ -185,7 +190,7 @@ def train(model, criterion, optimizer,
if global_step % 10 == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"grad_norm": grad_norm.item(),
"step_time": step_time
}
iter_stats.update(loss_dict)
@ -335,16 +340,6 @@ def main(args): # pylint: disable=redefined-outer-name
# setup optimizers
optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)
# DISTRIBUTED
if c.apex_amp_level is not None:
# pylint: disable=import-outside-toplevel
from apex import amp
model.cuda()
# optimizer.cuda()
model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
else:
amp = None
# schedulers
scheduler = None
if 'lr_scheduler' in c:
@ -374,10 +369,6 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict)
del model_dict
# DISTRUBUTED
if amp and 'amp' in checkpoint:
amp.load_state_dict(checkpoint['amp'])
# reset lr if not countinuining training.
for group in optimizer.param_groups:
group['lr'] = c.lr
@ -410,7 +401,7 @@ def main(args): # pylint: disable=redefined-outer-name
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer,
scheduler, ap, global_step,
epoch, amp)
epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap,
global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@ -426,8 +417,7 @@ def main(args): # pylint: disable=redefined-outer-name
global_step,
epoch,
OUT_PATH,
model_losses=eval_avg_loss_dict,
amp_state_dict=amp.state_dict() if amp else None)
model_losses=eval_avg_loss_dict)
if __name__ == '__main__':
@ -481,8 +471,8 @@ if __name__ == '__main__':
_ = os.path.dirname(os.path.realpath(__file__))
# DISTRIBUTED
if c.apex_amp_level is not None:
print(" > apex AMP level: ", c.apex_amp_level)
if c.mixed_precision:
print(" > Mixed precision is enabled")
OUT_PATH = args.continue_path
if args.continue_path == '':

View File

@ -34,7 +34,7 @@
},
// DISTRIBUTED TRAINING
"apex_amp_level": "O1", // APEX amp optimization level. "O1" is currently supported.
"mixed_precision": true, // enable torch mixed precision training (true, false)
"distributed":{
"backend": "nccl",
"url": "tcp:\/\/localhost:54322"
@ -98,7 +98,7 @@
// TENSORBOARD and LOGGING
"print_step": 50, // Number of steps to log traning on console.
"print_eval": false, // If True, it prints loss values for each step in eval run.
"save_step": 10000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
"save_step": 5000, // Number of training steps expected to plot training stats on TB and save model checkpoints.
"checkpoint": true, // If true, it saves checkpoints per "save_step"
"tb_model_param_stats": true, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

View File

@ -2,6 +2,7 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm
from math import log as ln
@ -13,36 +14,59 @@ class Conv1d(nn.Conv1d):
nn.init.zeros_(self.bias)
# class PositionalEncoding(nn.Module):
# def __init__(self, n_channels):
# super().__init__()
# self.n_channels = n_channels
# self.length = n_channels // 2
# assert n_channels % 2 == 0
# def forward(self, x, noise_level):
# """
# Shapes:
# x: B x C x T
# noise_level: B
# """
# return (x + self.encoding(noise_level)[:, :, None])
# def encoding(self, noise_level):
# step = torch.arange(
# self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
# encoding = noise_level.unsqueeze(1) * torch.exp(
# -ln(1e4) * step.unsqueeze(0))
# encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
# return encoding
class PositionalEncoding(nn.Module):
def __init__(self, n_channels):
def __init__(self, n_channels, max_len=10000):
super().__init__()
self.n_channels = n_channels
self.length = n_channels // 2
assert n_channels % 2 == 0
self.max_len = max_len
self.C = 5000
self.pe = torch.zeros(0, 0)
def forward(self, x, noise_level):
"""
Shapes:
x: B x C x T
noise_level: B
"""
return (x + self.encoding(noise_level)[:, :, None])
if x.shape[2] > self.pe.shape[1]:
self.init_pe_matrix(x.shape[1] ,x.shape[2], x)
return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C
def encoding(self, noise_level):
step = torch.arange(
self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
encoding = noise_level.unsqueeze(1) * torch.exp(
-ln(1e4) * step.unsqueeze(0))
encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
return encoding
def init_pe_matrix(self, n_channels, max_len, x):
pe = torch.zeros(max_len, n_channels)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
pe[:, 0::2] = torch.sin(position / div_term)
pe[:, 1::2] = torch.cos(position / div_term)
self.pe = pe.transpose(0, 1).to(x)
class FiLM(nn.Module):
def __init__(self, input_size, output_size):
super().__init__()
self.encoding = PositionalEncoding(input_size)
self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
self.ini_parameters()
def ini_parameters(self):
@ -72,30 +96,30 @@ class UBlock(nn.Module):
assert len(dilation) == 4
self.factor = factor
self.block1 = Conv1d(input_size, hidden_size, 1)
self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
self.block2 = nn.ModuleList([
Conv1d(input_size,
weight_norm(Conv1d(input_size,
hidden_size,
3,
dilation=dilation[0],
padding=dilation[0]),
Conv1d(hidden_size,
padding=dilation[0])),
weight_norm(Conv1d(hidden_size,
hidden_size,
3,
dilation=dilation[1],
padding=dilation[1])
padding=dilation[1]))
])
self.block3 = nn.ModuleList([
Conv1d(hidden_size,
weight_norm(Conv1d(hidden_size,
hidden_size,
3,
dilation=dilation[2],
padding=dilation[2]),
Conv1d(hidden_size,
padding=dilation[2])),
weight_norm(Conv1d(hidden_size,
hidden_size,
3,
dilation=dilation[3],
padding=dilation[3])
padding=dilation[3]))
])
def forward(self, x, shift, scale):
@ -129,11 +153,11 @@ class DBlock(nn.Module):
def __init__(self, input_size, hidden_size, factor):
super().__init__()
self.factor = factor
self.residual_dense = Conv1d(input_size, hidden_size, 1)
self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
self.conv = nn.ModuleList([
Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
])
def forward(self, x):