mirror of https://github.com/coqui-ai/TTS.git
Separate backward pass for stop-token prediction
parent 5a5b9263e2
commit dc3f909ddc

train.py (29 changed lines)
@@ -59,7 +59,7 @@ LOG_DIR = OUT_PATH
 tb = SummaryWriter(LOG_DIR)
 
 
-def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
+def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st, epoch):
     model = model.train()
     epoch_time = 0
     avg_linear_loss = 0
@@ -88,10 +88,15 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
 
         # setup lr
         current_lr = lr_decay(c.lr, current_step, c.warmup_steps)
 
         for params_group in optimizer.param_groups:
             params_group['lr'] = current_lr
 
+        for params_group in optimizer_st.param_groups:
+            params_group['lr'] = current_lr
+
         optimizer.zero_grad()
+        optimizer_st.zero_grad()
+
         # dispatch data to GPU
         if use_cuda:
@@ -112,16 +117,25 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
                               linear_input[:, :, :n_priority_freq],
                               mel_lengths)
-        loss = mel_loss + linear_loss + stop_loss
+        loss = mel_loss + linear_loss
 
-        # backpass and check the grad norm
-        loss.backward()
+        # backpass and check the grad norm for spec losses
+        loss.backward(retain_graph=True)
         grad_norm, skip_flag = check_update(model, 0.5, 100)
         if skip_flag:
             optimizer.zero_grad()
             print(" | > Iteration skipped!!")
             continue
         optimizer.step()
 
+        # backpass and check the grad norm for stop loss
+        stop_loss.backward()
+        grad_norm_st, skip_flag = check_update(model.module.decoder.stopnet, 0.5, 100)
+        if skip_flag:
+            optimizer_st.zero_grad()
+            print(" | > Iteration skipped for stopnet!!")
+            continue
+        optimizer_st.step()
+
         step_time = time.time() - start_time
         epoch_time += step_time
@@ -131,7 +145,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
                                   ('linear_loss', linear_loss.item()),
                                   ('mel_loss', mel_loss.item()),
                                   ('stop_loss', stop_loss.item()),
-                                  ('grad_norm', grad_norm.item())])
+                                  ('grad_norm', grad_norm.item()),
+                                  ('grad_norm_st', grad_norm_st.item())])
         avg_linear_loss += linear_loss.item()
         avg_mel_loss += mel_loss.item()
         avg_stop_loss += stop_loss.item()
@@ -144,6 +159,7 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
             tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
                           current_step)
             tb.add_scalar('Params/GradNorm', grad_norm, current_step)
+            tb.add_scalar('Params/GradNormSt', grad_norm_st, current_step)
             tb.add_scalar('Time/StepTime', step_time, current_step)
 
         if current_step % c.save_step == 0:
@@ -345,6 +361,7 @@ def main(args):
                 c.r)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
+    optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
 
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
@@ -378,7 +395,7 @@ def main(args):
 
     for epoch in range(0, c.epochs):
         train_loss, current_step = train(
-            model, criterion, criterion_st, train_loader, optimizer, epoch)
+            model, criterion, criterion_st, train_loader, optimizer, optimizer_st, epoch)
         val_loss = evaluate(model, criterion, criterion_st, val_loader, current_step)
         best_loss = save_best_model(model, optimizer, val_loss,
                                     best_loss, OUT_PATH,
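For readers skimming the diff, the technique is: drop stop_loss from the combined loss, run one backward pass for the spectrogram losses with retain_graph=True (both losses share the decoder's forward graph), then run a second backward pass for the stop-token loss and step it with a dedicated Adam optimizer that sees only model.decoder.stopnet.parameters(). Below is a minimal, self-contained sketch of that two-backward, two-optimizer pattern; the toy modules, shapes, and targets are illustrative stand-ins, not code from this repository, and both backward() calls are ordered before either step() because current PyTorch rejects in-place parameter updates between backward passes over a retained graph (the 2018-era code could step in between).

import torch
import torch.nn as nn
import torch.optim as optim

class Stopnet(nn.Module):
    # stand-in for model.decoder.stopnet: one stop probability per frame
    def __init__(self, dim):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, x):
        return self.layer(x).squeeze(-1)

class ToyModel(nn.Module):
    # toy decoder + stopnet sharing a single forward graph
    def __init__(self, dim=80):
        super().__init__()
        self.decoder = nn.Linear(dim, dim)
        self.stopnet = Stopnet(dim)

    def forward(self, x):
        frames = self.decoder(x)
        return frames, self.stopnet(frames)

model = ToyModel()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
optimizer_st = optim.Adam(model.stopnet.parameters(), lr=1e-3)
criterion = nn.L1Loss()        # stands in for L1LossMasked()
criterion_st = nn.BCELoss()

x = torch.randn(4, 80)
frame_target = torch.randn(4, 80)
stop_target = torch.ones(4)    # 1.0 where decoding should stop

optimizer.zero_grad()
optimizer_st.zero_grad()

frames, stop_pred = model(x)
loss = criterion(frames, frame_target)            # spec loss: no stop term
stop_loss = criterion_st(stop_pred, stop_target)  # trained separately

loss.backward(retain_graph=True)   # keep the shared graph alive
stop_loss.backward()               # second pass over the retained graph
optimizer.step()                   # updates every model parameter
optimizer_st.step()                # updates only the stopnet

One caveat: stop_loss.backward() also deposits gradients in the decoder parameters it flows through. The commit tolerates this because optimizer_st steps only the stopnet and the leftovers are cleared by the zero_grad() calls at the top of the next iteration; in the reordered sketch above they would also reach optimizer.step(), and the usual way to rule that out entirely is to feed the stopnet a detached copy of the decoder output.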