Use step-wise LR scheduler + adapt train.py to pass the sequence mask directly

pull/10/head
Eren 2018-08-10 17:48:19 +02:00
parent e0bce1d2d1
commit 96e2e3c776
1 changed file with 40 additions and 48 deletions


@@ -16,11 +16,12 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import (
synthesis, remove_experiment_folder, create_experiment_folder,
save_checkpoint, save_best_model, load_config, lr_decay, count_parameters,
check_update, get_commit_hash)
check_update, get_commit_hash, sequence_mask)
from utils.visual import plot_alignment, plot_spectrogram
from models.tacotron import Tacotron
from layers.losses import L1LossMasked
from utils.audio import AudioProcessor
from torch.optim.lr_scheduler import StepLR
torch.manual_seed(1)
torch.set_num_threads(4)
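The newly imported sequence_mask builds a boolean padding mask from the batch's text lengths. A minimal sketch of such a helper, assuming the usual arange-vs-lengths comparison (the actual implementation in utils.generic_utils may differ):

```python
import torch

def sequence_mask(sequence_lengths, max_len=None):
    # Minimal sketch, not the repo's exact code: True for real time steps,
    # False for padded positions, shaped (batch, max_len).
    if max_len is None:
        max_len = int(sequence_lengths.max().item())
    steps = torch.arange(max_len, device=sequence_lengths.device)
    return steps.unsqueeze(0) < sequence_lengths.unsqueeze(1)

# e.g. lengths [3, 1] -> [[True, True, True], [True, False, False]]
print(sequence_mask(torch.tensor([3, 1])))
```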
@@ -28,7 +29,7 @@ use_cuda = torch.cuda.is_available()
def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
ap, epoch):
scheduler, ap, epoch):
model = model.train()
epoch_time = 0
avg_linear_loss = 0
@@ -58,15 +59,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
epoch * len(data_loader) + 1
# setup lr
current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
current_lr_st = lr_decay(c.lr, current_step, c.warmup_steps)
for params_group in optimizer.param_groups:
params_group['lr'] = current_lr
for params_group in optimizer_st.param_groups:
params_group['lr'] = current_lr_st
scheduler.step()
optimizer.zero_grad()
optimizer_st.zero_grad()
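For context, lr_decay above recomputes a warmup-style learning rate by hand every iteration; the commit swaps it for torch.optim.lr_scheduler.StepLR, which multiplies the learning rate by gamma every step_size calls to scheduler.step(). Since scheduler.step() is called once per iteration here, step_size is measured in iterations, not epochs. A small self-contained sketch of the pattern (the real step_size and gamma come from c.decay_step and c.lr_decay, constructed in main() further down):

```python
import torch
from torch.optim.lr_scheduler import StepLR

# Self-contained sketch; step_size=3 / gamma=0.5 stand in for
# c.decay_step / c.lr_decay from the training config.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=3, gamma=0.5)

for step in range(9):
    scheduler.step()                       # LR halves every 3 steps
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    print(step, optimizer.param_groups[0]['lr'])
```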
@@ -79,9 +72,12 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
linear_input = linear_input.cuda()
stop_targets = stop_targets.cuda()
# compute mask for padding
mask = sequence_mask(text_lengths)
# forward pass
mel_output, linear_output, alignments, stop_tokens =\
model.forward(text_input, mel_input, text_lengths)
mel_output, linear_output, alignments, stop_tokens = torch.nn.parallel.data_parallel(
model, (text_input, mel_input, mask))
# loss computation
stop_loss = criterion_st(stop_tokens, stop_targets)
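Note the change from calling model.forward directly to the functional torch.nn.parallel.data_parallel, which replicates the module and scatters the positional inputs, including the new mask, across the visible GPUs for this one call (the main() hunks below drop the nn.DataParallel wrapper accordingly). A rough sketch of the call pattern with a toy module standing in for Tacotron:

```python
import torch
from torch.nn.parallel import data_parallel

class ToyModel(torch.nn.Module):
    # Stand-in for Tacotron; the real forward takes (text, mel, mask).
    def forward(self, text_input, mel_input, mask):
        return mel_input * mask.unsqueeze(-1).float()

model = ToyModel()
text_input = torch.randint(0, 30, (8, 20))
mel_input = torch.randn(8, 20, 80)
mask = torch.ones(8, 20, dtype=torch.bool)

if torch.cuda.is_available():
    model = model.cuda()
    text_input, mel_input, mask = text_input.cuda(), mel_input.cuda(), mask.cuda()
    # batch is split along dim 0 and each slice runs on its own GPU
    mel_out = data_parallel(model, (text_input, mel_input, mask))
else:
    mel_out = model(text_input, mel_input, mask)
```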
@@ -94,7 +90,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# backpass and check the grad norm for spec losses
loss.backward(retain_graph=True)
grad_norm, skip_flag = check_update(model, 0.5, 100)
grad_norm, skip_flag = check_update(model, 1)
if skip_flag:
optimizer.zero_grad()
print(" | > Iteration skipped!!", flush=True)
@@ -103,8 +99,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
# backpass and check the grad norm for stop loss
stop_loss.backward()
grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet,
0.5, 100)
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
if skip_flag:
optimizer_st.zero_grad()
print(" | | > Iteration skipped fro stopnet!!")
@@ -115,18 +110,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
epoch_time += step_time
if current_step % c.print_step == 0:
print(" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
"GradNormST:{:.5f} StepTime:{:.2f}".format(num_iter,
batch_n_iter,
current_step,
loss.item(),
linear_loss.item(),
mel_loss.item(),
stop_loss.item(),
grad_norm.item(),
grad_norm_st.item(),
step_time), flush=True)
print(
" | | > Step:{}/{} GlobalStep:{} TotalLoss:{:.5f} LinearLoss:{:.5f} "
"MelLoss:{:.5f} StopLoss:{:.5f} GradNorm:{:.5f} "
"GradNormST:{:.5f} StepTime:{:.2f}".format(
num_iter, batch_n_iter, current_step, loss.item(),
linear_loss.item(), mel_loss.item(), stop_loss.item(),
grad_norm.item(), grad_norm_st.item(), step_time),
flush=True)
avg_linear_loss += linear_loss.item()
avg_mel_loss += mel_loss.item()
@@ -184,16 +175,14 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
avg_step_time /= (num_iter + 1)
# print epoch stats
print(" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f}".format(current_step,
avg_total_loss,
avg_linear_loss,
avg_mel_loss,
avg_stop_loss,
epoch_time,
avg_step_time), flush=True)
print(
" | | > EPOCH END -- GlobalStep:{} AvgTotalLoss:{:.5f} "
"AvgLinearLoss:{:.5f} AvgMelLoss:{:.5f} "
"AvgStopLoss:{:.5f} EpochTime:{:.2f} "
"AvgStepTime:{:.2f}".format(current_step, avg_total_loss,
avg_linear_loss, avg_mel_loss,
avg_stop_loss, epoch_time, avg_step_time),
flush=True)
# Plot Training Epoch Stats
tb.add_scalar('TrainEpochLoss/TotalLoss', avg_total_loss, current_step)
@@ -266,8 +255,10 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
if num_iter % c.print_step == 0:
print(
" | | > TotalLoss: {:.5f} LinearLoss: {:.5f} MelLoss:{:.5f} "
"StopLoss: {:.5f} ".format(loss.item(), linear_loss.item(),
mel_loss.item(), stop_loss.item()),
"StopLoss: {:.5f} ".format(loss.item(),
linear_loss.item(),
mel_loss.item(),
stop_loss.item()),
flush=True)
avg_linear_loss += linear_loss.item()
@@ -322,11 +313,11 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
ap.griffin_lim_iters = 60
for idx, test_sentence in enumerate(test_sentences):
try:
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
use_cuda, c.text_cleaner)
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate)
wav, linear_spec, alignments = synthesis(model, ap, test_sentence,
use_cuda, c.text_cleaner)
wav_name = 'TestSentences/{}'.format(idx)
tb.add_audio(
wav_name, wav, current_step, sample_rate=c.sample_rate)
except:
print(" !! Error as creating Test Sentence -", idx)
pass
@@ -405,7 +396,7 @@ def main(args):
checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model'])
if use_cuda:
model = nn.DataParallel(model.cuda())
model = model.cuda()
criterion.cuda()
criterion_st.cuda()
optimizer.load_state_dict(checkpoint['optimizer'])
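Because the model is no longer wrapped in nn.DataParallel before restoring, a checkpoint saved from a wrapped model would carry a 'module.' prefix on every parameter name. The commit does not handle that case; if it comes up, one hypothetical workaround is to strip the prefix before load_state_dict:

```python
def strip_data_parallel_prefix(state_dict):
    # Hypothetical helper, not part of this commit: drop the 'module.' prefix
    # that nn.DataParallel adds to parameter names when saving.
    return {k.replace('module.', '', 1): v for k, v in state_dict.items()}

# usage against the names in the diff above:
# model.load_state_dict(strip_data_parallel_prefix(checkpoint['model']))
```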
@@ -423,10 +414,11 @@
args.restore_step = 0
print("\n > Starting a new training", flush=True)
if use_cuda:
model = nn.DataParallel(model.cuda())
model = model.cuda()
criterion.cuda()
criterion_st.cuda()
scheduler = StepLR(optimizer, step_size=c.decay_step, gamma=c.lr_decay)
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params), flush=True)
@@ -439,7 +431,7 @@
for epoch in range(0, c.epochs):
train_loss, current_step = train(model, criterion, criterion_st,
train_loader, optimizer, optimizer_st,
ap, epoch)
scheduler, ap, epoch)
val_loss = evaluate(model, criterion, criterion_st, val_loader, ap,
current_step)
print(