Separate backward pass for stop-token prediction

pull/10/head
Eren Golge 2018-05-11 08:38:07 -07:00
parent 5a5b9263e2
commit dc3f909ddc
1 changed file with 23 additions and 6 deletions
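In short: the stop-token loss now gets its own optimizer (optimizer_st) and its own backward pass instead of being folded into the combined spectrogram loss. Since both losses share the decoder's graph, the first backward has to pass retain_graph=True so the second backward can traverse the graph again. A minimal PyTorch sketch of the pattern (the toy model and losses are illustrative, not this repo's code):

    import torch
    import torch.nn as nn

    # Toy stand-ins: a shared trunk feeding a spectrogram head and a stop head.
    trunk = nn.Linear(10, 16)
    spec_head = nn.Linear(16, 80)
    stop_head = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())

    # A main optimizer plus a dedicated stop-net optimizer, as in this commit.
    optimizer = torch.optim.Adam(
        list(trunk.parameters()) + list(spec_head.parameters()), lr=1e-3)
    optimizer_st = torch.optim.Adam(stop_head.parameters(), lr=1e-3)

    x = torch.randn(4, 10)
    h = trunk(x)
    spec_loss = spec_head(h).pow(2).mean()      # stands in for mel + linear loss
    stop_loss = nn.functional.binary_cross_entropy(
        stop_head(h), torch.ones(4, 1))         # stands in for the BCE stop loss

    optimizer.zero_grad()
    optimizer_st.zero_grad()
    spec_loss.backward(retain_graph=True)  # keep the shared graph alive
    stop_loss.backward()                   # second pass through the same trunk
    optimizer.step()
    optimizer_st.step()

(The commit itself steps the main optimizer before running the stop backward; on recent PyTorch versions it is safer to finish both backward passes before any in-place parameter update, as above.)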

@@ -59,7 +59,7 @@ LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)
 
 
-def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -88,10 +88,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
+        for params_group in optimizer_st.param_groups:
+            params_group['lr'] = current_lr
 
         optimizer.zero_grad()
+        optimizer_st.zero_grad()
 
         # dispatch data to GPU
         if use_cuda:
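Both optimizers are kept on the same decayed learning rate each step. lr_decay is the repo's warmup helper; a Noam-style warmup/decay of the kind such a helper typically implements looks like this (a sketch, not necessarily the exact function in this repo):

    def lr_decay(init_lr, global_step, warmup_steps):
        # Ramp up for `warmup_steps` steps, then decay as 1/sqrt(step).
        step = global_step + 1
        return init_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)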
@@ -112,16 +117,25 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
-        loss = mel_loss + linear_loss + stop_loss
+        loss = mel_loss + linear_loss
 
-        # backpass and check the grad norm
-        loss.backward()
+        # backpass and check the grad norm for spec losses
+        loss.backward(retain_graph=True)
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()
 
+        # backpass and check the grad norm for stop loss
+        stop_loss.backward()
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        if skip_flag:
+            optimizer_st.zero_grad()
+            print(" | > Iteration skipped for stopnet!!")
+            continue
+        optimizer_st.step()
 
         step_time = time.time() - start_time
         epoch_time += step_time
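check_update is the repo's gradient-sanity helper. The call sites above suggest its contract: clip the gradient norm of the given module to the first argument and flag a skip when the norm blows past the second. A plausible implementation under that reading (an assumption, not the repo's actual code):

    import torch

    def check_update(model, grad_clip, grad_top):
        # Clip the global grad norm to `grad_clip`; tell the caller to skip
        # the update when the pre-clip norm exceeds `grad_top` (loss spike).
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        skip_flag = bool(grad_norm > grad_top)
        return grad_norm, skip_flag

Note that the second check covers only model.module.decoder.stopnet, so a stop-loss spike skips just the optimizer_st update; the main optimizer.step() has already been applied by then.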
@@ -131,7 +145,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
                                ('linear_loss', linear_loss.item()),
                                ('mel_loss', mel_loss.item()),
                                ('stop_loss', stop_loss.item()),
-                               ('grad_norm', grad_norm.item())])
+                               ('grad_norm', grad_norm.item()),
+                               ('grad_norm_st', grad_norm_st.item())])
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
@@ -144,6 +159,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
         tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                       current_step)
         tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+        tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
         tb.add_scalar('Time/StepTime', step_time, current_step)
 
         if current_step % c.save_step == 0:
@@ -345,6 +361,7 @@ def main(args):
                      c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
 
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
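nn.BCELoss() as criterion_st implies the stopnet emits per-frame probabilities in [0, 1] and is trained against binary stop targets. A toy construction of such targets (batch shape and padding convention are assumptions, not the repo's data pipeline):

    import torch
    import torch.nn as nn

    criterion_st = nn.BCELoss()
    # Hypothetical batch: 4 utterances padded to 100 decoder frames; the target
    # is 0 while speech continues and 1 from each utterance's last frame onward.
    lengths = [100, 80, 60, 90]
    stop_targets = torch.zeros(4, 100)
    for i, n in enumerate(lengths):
        stop_targets[i, n - 1:] = 1.0
    stop_probs = torch.rand(4, 100)  # stands in for the stopnet's sigmoid output
    stop_loss = criterion_st(stop_probs, stop_targets)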
@@ -378,7 +395,7 @@
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(
-            model, criterion, criterion_st, train_loader, optimizer, epoch)
+            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
 
         best_loss = save_best_model(model, optimizer, val_loss,
                                     best_loss, OUT_PATH,