Weight norm and torch-based AMP training for WaveGrad

pull/10/head
erogol 2020-10-27 12:06:57 +01:00
parent a3213762ae
commit 14c2381207
3 changed files with 97 additions and 83 deletions

Changed file 1/3: WaveGrad training script

@@ -7,10 +7,7 @@ import traceback
 import torch
 # DISTRIBUTED
-try:
-    from apex.parallel import DistributedDataParallel as DDP_apex
-except:
-    from torch.nn.parallel import DistributedDataParallel as DDP_th
+from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.optim import Adam
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
@@ -82,7 +79,7 @@ def format_test_data(data):
 def train(model, criterion, optimizer,
-          scheduler, ap, global_step, epoch, amp):
+          scheduler, ap, global_step, epoch):
     data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
     model.train()
     epoch_time = 0
@@ -104,6 +101,7 @@ def train(model, criterion, optimizer,
         model.compute_noise_level(noise_schedule['num_steps'],
                                   noise_schedule['min_val'],
                                   noise_schedule['max_val'])
+    scaler = torch.cuda.amp.GradScaler()
     for num_iter, data in enumerate(data_loader):
         start_time = time.time()
@@ -113,39 +111,46 @@ def train(model, criterion, optimizer,
         global_step += 1

-        # compute noisy input
-        if hasattr(model, 'module'):
-            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
-        else:
-            noise, x_noisy, noise_scale = model.compute_y_n(x)
-
-        # forward pass
-        noise_hat = model(x_noisy, m, noise_scale)
-
-        # compute losses
-        loss = criterion(noise, noise_hat)
-        # if loss.item() > 100:
-        #     breakpoint()
+        with torch.cuda.amp.autocast():
+            # compute noisy input
+            if hasattr(model, 'module'):
+                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
+            else:
+                noise, x_noisy, noise_scale = model.compute_y_n(x)
+
+            # forward pass
+            noise_hat = model(x_noisy, m, noise_scale)
+
+            # compute losses
+            loss = criterion(noise, noise_hat)
         loss_wavegrad_dict = {'wavegrad_loss':loss}

-        # backward pass with loss scaling
+        # check nan loss
+        if torch.isnan(loss).any():
+            raise RuntimeError(f'Detected NaN loss at step {global_step}.')
+
         optimizer.zero_grad()
-        if amp is not None:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
-        else:
-            loss.backward()
-
-        if c.clip_grad > 0:
-            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
-                                                       c.clip_grad)
-        optimizer.step()

         # schedule update
         if scheduler is not None:
             scheduler.step()

+        # backward pass with loss scaling
+        if c.mixed_precision:
+            scaler.scale(loss).backward()
+            scaler.unscale_(optimizer)
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            loss.backward()
+            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
+                                                       c.clip_grad)
+            optimizer.step()
+
         # disconnect loss values
         loss_dict = dict()
         for key, value in loss_wavegrad_dict.items():
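For readers unfamiliar with the native AMP API used in the new branch above, the following is a minimal, self-contained sketch of the same autocast / GradScaler / clip_grad_norm_ sequence. The model, data, and clip value are placeholders for illustration, not the project's training setup, and a CUDA device is assumed.

    # Minimal sketch of the torch.cuda.amp pattern adopted in this commit.
    import torch

    model = torch.nn.Linear(80, 80).cuda()
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(8, 80, device='cuda')
    target = torch.randn(8, 80, device='cuda')

    for step in range(10):
        optimizer.zero_grad()
        # run the forward pass in mixed precision
        with torch.cuda.amp.autocast():
            y = model(x)
            loss = criterion(y, target)
        # scale the loss, backprop, then unscale before gradient clipping
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        # scaler.step skips the optimizer update if gradients overflowed
        scaler.step(optimizer)
        scaler.update()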
@@ -175,7 +180,7 @@ def train(model, criterion, optimizer,
                 'step_time': [step_time, 2],
                 'loader_time': [loader_time, 4],
                 "current_lr": current_lr,
-                "grad_norm": grad_norm
+                "grad_norm": grad_norm.item()
             }
             c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                       log_dict, loss_dict, keep_avg.avg_values)
@@ -185,7 +190,7 @@ def train(model, criterion, optimizer,
         if global_step % 10 == 0:
             iter_stats = {
                 "lr": current_lr,
-                "grad_norm": grad_norm,
+                "grad_norm": grad_norm.item(),
                 "step_time": step_time
             }
             iter_stats.update(loss_dict)
@@ -335,16 +340,6 @@ def main(args):  # pylint: disable=redefined-outer-name
     # setup optimizers
     optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

-    # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        # pylint: disable=import-outside-toplevel
-        from apex import amp
-        model.cuda()
-        # optimizer.cuda()
-        model, optimizer = amp.initialize(model, optimizer, opt_level=c.apex_amp_level)
-    else:
-        amp = None
-
     # schedulers
     scheduler = None
     if 'lr_scheduler' in c:
@@ -374,10 +369,6 @@ def main(args):  # pylint: disable=redefined-outer-name
         model.load_state_dict(model_dict)
         del model_dict

-        # DISTRUBUTED
-        if amp and 'amp' in checkpoint:
-            amp.load_state_dict(checkpoint['amp'])
-
         # reset lr if not countinuining training.
         for group in optimizer.param_groups:
             group['lr'] = c.lr
@@ -410,7 +401,7 @@ def main(args):  # pylint: disable=redefined-outer-name
         c_logger.print_epoch_start(epoch, c.epochs)
         _, global_step = train(model, criterion, optimizer,
                                scheduler, ap, global_step,
-                               epoch, amp)
+                               epoch)
         eval_avg_loss_dict = evaluate(model, criterion, ap,
                                       global_step, epoch)
         c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
@@ -426,8 +417,7 @@ def main(args):  # pylint: disable=redefined-outer-name
                             global_step,
                             epoch,
                             OUT_PATH,
-                            model_losses=eval_avg_loss_dict,
-                            amp_state_dict=amp.state_dict() if amp else None)
+                            model_losses=eval_avg_loss_dict)


 if __name__ == '__main__':
@@ -481,8 +471,8 @@ if __name__ == '__main__':
     _ = os.path.dirname(os.path.realpath(__file__))

     # DISTRIBUTED
-    if c.apex_amp_level is not None:
-        print(" > apex AMP level: ", c.apex_amp_level)
+    if c.mixed_precision:
+        print(" > Mixed precision is enabled")

     OUT_PATH = args.continue_path
     if args.continue_path == '':
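One side effect of dropping apex is that no AMP state is checkpointed any more (the amp_state_dict argument above was removed). If exact resumption of mixed-precision runs were wanted, the GradScaler state could be saved and restored instead. The sketch below is only a suggestion and is not part of this commit; the 'scaler' checkpoint key and the file name are made up.

    # Hypothetical checkpointing of the GradScaler state (not done in this commit).
    import torch

    scaler = torch.cuda.amp.GradScaler()

    # save the scaler alongside the usual model/optimizer state
    checkpoint = {'scaler': scaler.state_dict()}
    torch.save(checkpoint, 'checkpoint.pth')

    # restore it when continuing training
    checkpoint = torch.load('checkpoint.pth')
    if 'scaler' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler'])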

Changed file 2/3: WaveGrad training config (JSON)

@@ -34,7 +34,7 @@
     },

     // DISTRIBUTED TRAINING
-    "apex_amp_level": "O1",  // APEX amp optimization level. "O1" is currently supported.
+    "mixed_precision": true,  // enable torch mixed precision training (true, false)
     "distributed":{
         "backend": "nccl",
         "url": "tcp:\/\/localhost:54322"
@@ -98,7 +98,7 @@
     // TENSORBOARD and LOGGING
     "print_step": 50,  // Number of steps to log traning on console.
     "print_eval": false,  // If True, it prints loss values for each step in eval run.
-    "save_step": 10000,  // Number of training steps expected to plot training stats on TB and save model checkpoints.
+    "save_step": 5000,  // Number of training steps expected to plot training stats on TB and save model checkpoints.
     "checkpoint": true,  // If true, it saves checkpoints per "save_step"
     "tb_model_param_stats": true,  // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.

Changed file 3/3: WaveGrad model layers

@@ -2,6 +2,7 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.nn.utils import weight_norm

 from math import log as ln
@@ -13,36 +14,59 @@ class Conv1d(nn.Conv1d):
         nn.init.zeros_(self.bias)


+# class PositionalEncoding(nn.Module):
+#     def __init__(self, n_channels):
+#         super().__init__()
+#         self.n_channels = n_channels
+#         self.length = n_channels // 2
+#         assert n_channels % 2 == 0
+
+#     def forward(self, x, noise_level):
+#         """
+#         Shapes:
+#             x: B x C x T
+#             noise_level: B
+#         """
+#         return (x + self.encoding(noise_level)[:, :, None])
+
+#     def encoding(self, noise_level):
+#         step = torch.arange(
+#             self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
+#         encoding = noise_level.unsqueeze(1) * torch.exp(
+#             -ln(1e4) * step.unsqueeze(0))
+#         encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
+#         return encoding
+
+
 class PositionalEncoding(nn.Module):
-    def __init__(self, n_channels):
+    def __init__(self, n_channels, max_len=10000):
         super().__init__()
         self.n_channels = n_channels
-        self.length = n_channels // 2
-        assert n_channels % 2 == 0
+        self.max_len = max_len
+        self.C = 5000
+        self.pe = torch.zeros(0, 0)

     def forward(self, x, noise_level):
-        """
-        Shapes:
-            x: B x C x T
-            noise_level: B
-        """
-        return (x + self.encoding(noise_level)[:, :, None])
+        if x.shape[2] > self.pe.shape[1]:
+            self.init_pe_matrix(x.shape[1], x.shape[2], x)
+        return x + noise_level[..., None, None] + self.pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / self.C

-    def encoding(self, noise_level):
-        step = torch.arange(
-            self.length, dtype=noise_level.dtype, device=noise_level.device) / self.length
-        encoding = noise_level.unsqueeze(1) * torch.exp(
-            -ln(1e4) * step.unsqueeze(0))
-        encoding = torch.cat([torch.sin(encoding), torch.cos(encoding)], dim=-1)
-        return encoding
+    def init_pe_matrix(self, n_channels, max_len, x):
+        pe = torch.zeros(max_len, n_channels)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
+
+        pe[:, 0::2] = torch.sin(position / div_term)
+        pe[:, 1::2] = torch.cos(position / div_term)
+        self.pe = pe.transpose(0, 1).to(x)


 class FiLM(nn.Module):
     def __init__(self, input_size, output_size):
         super().__init__()
         self.encoding = PositionalEncoding(input_size)
-        self.input_conv = nn.Conv1d(input_size, input_size, 3, padding=1)
-        self.output_conv = nn.Conv1d(input_size, output_size * 2, 3, padding=1)
+        self.input_conv = weight_norm(nn.Conv1d(input_size, input_size, 3, padding=1))
+        self.output_conv = weight_norm(nn.Conv1d(input_size, output_size * 2, 3, padding=1))
         self.ini_parameters()

     def ini_parameters(self):
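The rewritten PositionalEncoding caches a standard sinusoidal table once and adds the diffusion noise level as a per-sample offset. The following self-contained sketch reproduces that construction outside the module, with illustrative shapes.

    # Sketch of the cached sinusoidal table built by init_pe_matrix above.
    import torch

    n_channels, max_len = 128, 400            # illustrative sizes
    pe = torch.zeros(max_len, n_channels)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.pow(10000, torch.arange(0, n_channels, 2).float() / n_channels)
    pe[:, 0::2] = torch.sin(position / div_term)
    pe[:, 1::2] = torch.cos(position / div_term)
    pe = pe.transpose(0, 1)                   # -> [n_channels, max_len]

    x = torch.randn(2, n_channels, max_len)   # B x C x T feature map
    noise_level = torch.rand(2)               # one noise value per batch item
    out = x + noise_level[..., None, None] + pe[:, :x.size(2)].repeat(x.shape[0], 1, 1) / 5000
    print(out.shape)                          # torch.Size([2, 128, 400])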
@@ -72,30 +96,30 @@ class UBlock(nn.Module):
         assert len(dilation) == 4

         self.factor = factor
-        self.block1 = Conv1d(input_size, hidden_size, 1)
+        self.block1 = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.block2 = nn.ModuleList([
-            Conv1d(input_size,
-                   hidden_size,
-                   3,
-                   dilation=dilation[0],
-                   padding=dilation[0]),
-            Conv1d(hidden_size,
-                   hidden_size,
-                   3,
-                   dilation=dilation[1],
-                   padding=dilation[1])
+            weight_norm(Conv1d(input_size,
+                               hidden_size,
+                               3,
+                               dilation=dilation[0],
+                               padding=dilation[0])),
+            weight_norm(Conv1d(hidden_size,
+                               hidden_size,
+                               3,
+                               dilation=dilation[1],
+                               padding=dilation[1]))
         ])
         self.block3 = nn.ModuleList([
-            Conv1d(hidden_size,
-                   hidden_size,
-                   3,
-                   dilation=dilation[2],
-                   padding=dilation[2]),
-            Conv1d(hidden_size,
-                   hidden_size,
-                   3,
-                   dilation=dilation[3],
-                   padding=dilation[3])
+            weight_norm(Conv1d(hidden_size,
+                               hidden_size,
+                               3,
+                               dilation=dilation[2],
+                               padding=dilation[2])),
+            weight_norm(Conv1d(hidden_size,
+                               hidden_size,
+                               3,
+                               dilation=dilation[3],
+                               padding=dilation[3]))
         ])

     def forward(self, x, shift, scale):
@@ -129,11 +153,11 @@ class DBlock(nn.Module):
     def __init__(self, input_size, hidden_size, factor):
         super().__init__()
         self.factor = factor
-        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.residual_dense = weight_norm(Conv1d(input_size, hidden_size, 1))
         self.conv = nn.ModuleList([
-            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
-            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
-            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+            weight_norm(Conv1d(input_size, hidden_size, 3, dilation=1, padding=1)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2)),
+            weight_norm(Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4)),
         ])

     def forward(self, x):
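All convolutions in the model are now wrapped in weight_norm, which reparameterizes each weight into a magnitude (weight_g) and a direction (weight_v); the wrapper is typically stripped before inference. A minimal sketch of wrapping a conv this way and later removing the reparameterization is shown below; the layer sizes are arbitrary.

    # Sketch: wrapping a conv in weight norm and stripping it for inference.
    import torch
    import torch.nn as nn
    from torch.nn.utils import weight_norm, remove_weight_norm

    conv = weight_norm(nn.Conv1d(32, 32, kernel_size=3, padding=1))
    # the wrapped module now exposes weight_g (magnitude) and weight_v (direction)
    print(hasattr(conv, 'weight_g'), hasattr(conv, 'weight_v'))  # True True

    y = conv(torch.randn(1, 32, 100))  # used like a normal Conv1d
    print(y.shape)                     # torch.Size([1, 32, 100])

    # before export/inference the reparameterization is usually removed
    remove_weight_norm(conv)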