conditional stopnet and separate stopnet training

pull/10/head
Eren Golge 2019-05-14 17:49:20 +02:00
parent 66453b81d3
commit 82d873b35a
1 changed file with 32 additions and 24 deletions


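Everything in this diff keys off two config flags, c.stopnet and c.separate_stopnet. A minimal sketch of the assumed config entries (the key names come from the guards below; the surrounding config structure is an assumption):

# Hypothetical config snippet, for illustration only
config = {
    "stopnet": True,           # train a stop-token predictor at all
    "separate_stopnet": True,  # give the stopnet its own optimizer and backward pass
}

With stopnet=False the stop loss becomes a zero placeholder and neither criterion_st nor optimizer_st is created; with separate_stopnet=False the stop loss simply folds into the main loss.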
@@ -109,7 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.lr_decay:
             scheduler.step()
         optimizer.zero_grad()
-        optimizer_st.zero_grad()
+        if optimizer_st: optimizer_st.zero_grad();
 
         # dispatch data to GPU
         if use_cuda:
@@ -125,7 +125,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_input, text_lengths, mel_input)
 
         # loss computation
-        stop_loss = criterion_st(stop_tokens, stop_targets)
+        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model == "Tacotron":
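The hunk above swaps the BCE stop loss for a zero tensor when the stopnet is disabled, so downstream code can keep calling stop_loss.item() without branching. A self-contained sketch of the same guard (the shapes and the stopnet_enabled flag are illustrative stand-ins):

import torch
import torch.nn as nn

criterion_st = nn.BCEWithLogitsLoss()   # the criterion main() builds below
stopnet_enabled = False                 # stands in for c.stopnet

stop_tokens = torch.randn(2, 5)         # toy stopnet logits
stop_targets = torch.zeros(2, 5)
stop_targets[:, -1] = 1.0               # 1 marks the final decoder step

stop_loss = (criterion_st(stop_tokens, stop_targets)
             if stopnet_enabled else torch.zeros(1))
print(float(stop_loss.item()))          # works in both branches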
@@ -139,18 +139,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             postnet_loss = criterion(postnet_output, mel_input)
         loss = decoder_loss + postnet_loss
+        if not c.separate_stopnet and c.stopnet:
+            loss += stop_loss
 
         # backpass and check the grad norm for spec losses
-        loss.backward(retain_graph=True)
+        if c.separate_stopnet:
+            loss.backward(retain_graph=True)
+        else:
+            loss.backward()
         optimizer, current_lr = weight_decay(optimizer, c.wd)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
         # backpass and check the grad norm for stop loss
-        stop_loss.backward()
-        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
-        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
-        optimizer_st.step()
+        if c.separate_stopnet:
+            stop_loss.backward()
+            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            optimizer_st.step()
+        else:
+            grad_norm_st = 0
 
         step_time = time.time() - start_time
         epoch_time += step_time
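This hunk is the core of the change: with separate_stopnet, the spectrogram losses and the stop loss get independent backward passes over a shared graph, and only the stopnet parameters are updated by the second optimizer. A toy sketch of the pattern (the two Linear modules are stand-ins for the main model and model.decoder.stopnet):

import torch
import torch.nn as nn

backbone = nn.Linear(8, 8)   # stands in for the main model
stopnet = nn.Linear(8, 1)    # stands in for model.decoder.stopnet

optimizer = torch.optim.Adam(backbone.parameters(), lr=1e-3)
optimizer_st = torch.optim.Adam(stopnet.parameters(), lr=1e-3)

h = backbone(torch.randn(4, 8))
main_loss = h.pow(2).mean()
stop_loss = stopnet(h).sigmoid().mean()   # shares the graph through h

# retain_graph=True keeps the shared graph alive for the second backward,
# mirroring the separate_stopnet branch above.
main_loss.backward(retain_graph=True)
nn.utils.clip_grad_norm_(backbone.parameters(), 1.0)   # cf. check_update()
optimizer.step()

# The stop loss then gets its own backward and its own clipped step; only
# the stopnet optimizer moves here, just like optimizer_st.step() in the diff.
stop_loss.backward()
nn.utils.clip_grad_norm_(stopnet.parameters(), 1.0)
optimizer_st.step()

When separate_stopnet is off, a single loss.backward() suffices, which is why retain_graph=True is only requested in the separate case.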
@@ -175,7 +183,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             avg_postnet_loss += float(postnet_loss.item())
             avg_decoder_loss += float(decoder_loss.item())
-            avg_stop_loss += stop_loss.item()
+            avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
             avg_step_time += step_time
 
         # Plot Training Iter Stats
@@ -290,7 +298,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                     model.forward(text_input, text_lengths, mel_input)
 
                 # loss computation
-                stop_loss = criterion_st(stop_tokens, stop_targets)
+                stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
                 if c.loss_masking:
                     decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
                     if c.model == "Tacotron":
@@ -395,14 +403,17 @@ def main(args):
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
-    optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    if c.stopnet and c.separate_stopnet:
+        optimizer_st = optim.Adam(
+            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    else:
+        optimizer_st = None
 
     if c.loss_masking:
         criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
     else:
         criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss()
+    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
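After this hunk, criterion_st and optimizer_st exist only when the corresponding flags are set, so every later touch point (zero_grad, weight_decay, .cuda(), step) has to guard against None. A condensed sketch with toy stand-ins for the config and model:

import torch.nn as nn
from torch import optim

use_stopnet, separate_stopnet = True, True   # stand-ins for c.stopnet / c.separate_stopnet
stopnet = nn.Linear(8, 1)                    # stands in for model.decoder.stopnet

criterion_st = nn.BCEWithLogitsLoss() if use_stopnet else None
optimizer_st = (optim.Adam(stopnet.parameters(), lr=1e-3, weight_decay=0)
                if use_stopnet and separate_stopnet else None)

if optimizer_st:        # the guard pattern the rest of the diff relies on
    optimizer_st.zero_grad()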
@@ -420,23 +431,19 @@
         model_dict = set_init_dict(model_dict, checkpoint, c)
         model.load_state_dict(model_dict)
         del model_dict
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
         # best_loss = checkpoint['postnet_loss']
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+        if criterion_st: criterion_st.cuda();
 
     # DISTRUBUTED
     if num_gpus > 1:
@@ -457,9 +464,10 @@
 
     best_loss = float('inf')
     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st,
-                                         optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
+        # train_loss, current_step = train(model, criterion, criterion_st,
+        #                                  optimizer, optimizer_st, scheduler,
+        #                                  ap, epoch)
+        train_loss, current_step = 0, 0
         val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
         print(
             " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(