mirror of https://github.com/coqui-ai/TTS.git
conditional stopnet and separate stopnet training
parent 66453b81d3
commit 82d873b35a
train.py (56 changed lines)
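The change gates every use of the stop-token network ("stopnet") behind two config flags: c.stopnet enables the stop criterion and stop loss at all, and c.separate_stopnet additionally gives the stopnet its own Adam optimizer instead of folding its loss into the main objective. Below is a minimal sketch of that flag handling, assuming a plain namespace config and a model that exposes the stopnet at model.decoder.stopnet; it is an illustration of the intent, not the repository's exact code.

# Sketch only: mirrors the flag logic of the diff below with hypothetical values.
from types import SimpleNamespace

import torch.nn as nn
from torch import optim

# Hypothetical config; the real values come from config.json via c.*
c = SimpleNamespace(stopnet=True, separate_stopnet=True, lr=1e-4)

def build_stopnet_objects(model, c):
    # The BCE stop criterion only exists when the stopnet is enabled at all.
    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
    # A dedicated optimizer is only needed when the stopnet is trained separately.
    if c.stopnet and c.separate_stopnet:
        optimizer_st = optim.Adam(
            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
    else:
        optimizer_st = None
    return criterion_st, optimizer_st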
@@ -109,7 +109,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if c.lr_decay:
             scheduler.step()
         optimizer.zero_grad()
-        optimizer_st.zero_grad()
+        if optimizer_st: optimizer_st.zero_grad();
 
         # dispatch data to GPU
         if use_cuda:
@@ -125,7 +125,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             text_input, text_lengths, mel_input)
 
         # loss computation
-        stop_loss = criterion_st(stop_tokens, stop_targets)
+        stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
         if c.loss_masking:
             decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
             if c.model == "Tacotron":
@@ -139,18 +139,26 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         else:
             postnet_loss = criterion(postnet_output, mel_input)
         loss = decoder_loss + postnet_loss
+        if not c.separate_stopnet and c.stopnet:
+            loss += stop_loss
 
         # backpass and check the grad norm for spec losses
-        loss.backward(retain_graph=True)
+        if c.separate_stopnet:
+            loss.backward(retain_graph=True)
+        else:
+            loss.backward()
         optimizer, current_lr = weight_decay(optimizer, c.wd)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
 
         # backpass and check the grad norm for stop loss
-        stop_loss.backward()
-        optimizer_st, _ = weight_decay(optimizer_st, c.wd)
-        grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
-        optimizer_st.step()
+        if c.separate_stopnet:
+            stop_loss.backward()
+            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
+            optimizer_st.step()
+        else:
+            grad_norm_st = 0
 
         step_time = time.time() - start_time
         epoch_time += step_time
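For context, here is a self-contained toy sketch of the two update modes this hunk switches between: with separate_stopnet the main loss is backpropagated with retain_graph=True so the stop loss can run a second backward over the shared decoder activations and be applied by its own optimizer; without it the stop loss is simply added to the main loss and a single backward suffices. All module names are stand-ins, not the repository's classes, and both backward passes run before the optimizer steps to keep the example minimal.

# Toy illustration of why the first backward needs retain_graph=True only when
# the stopnet gets its own, second backward pass.
import torch
import torch.nn as nn
from torch import optim

separate_stopnet = True                      # stands in for c.separate_stopnet

decoder = nn.Linear(4, 4)                    # stands in for the decoder + postnet
stopnet = nn.Linear(4, 1)                    # stands in for model.decoder.stopnet
optimizer = optim.Adam(decoder.parameters(), lr=1e-3)
optimizer_st = optim.Adam(stopnet.parameters(), lr=1e-3) if separate_stopnet else None

frames = torch.randn(8, 4)
hidden = decoder(frames)                                     # shared graph
spec_loss = (hidden - frames).pow(2).mean()                  # decoder/postnet losses
stop_loss = nn.BCEWithLogitsLoss()(stopnet(hidden), torch.ones(8, 1))

optimizer.zero_grad()
if optimizer_st:
    optimizer_st.zero_grad()

if separate_stopnet:
    # The second backward reuses the shared part of the graph, so keep it alive.
    spec_loss.backward(retain_graph=True)
    stop_loss.backward()
    optimizer.step()        # decoder/postnet parameters
    optimizer_st.step()     # stopnet parameters only
else:
    # Stop loss folded into the main objective: one backward, one optimizer step.
    (spec_loss + stop_loss).backward()
    optimizer.step()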
@@ -175,7 +183,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         if args.rank == 0:
             avg_postnet_loss += float(postnet_loss.item())
             avg_decoder_loss += float(decoder_loss.item())
-            avg_stop_loss += stop_loss.item()
+            avg_stop_loss += stop_loss if type(stop_loss) is float else float(stop_loss.item())
             avg_step_time += step_time
 
         # Plot Training Iter Stats
@@ -290,7 +298,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
                 model.forward(text_input, text_lengths, mel_input)
 
             # loss computation
-            stop_loss = criterion_st(stop_tokens, stop_targets)
+            stop_loss = criterion_st(stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
             if c.loss_masking:
                 decoder_loss = criterion(decoder_output, mel_input, mel_lengths)
                 if c.model == "Tacotron":
@@ -395,14 +403,17 @@ def main(args):
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
 
     optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
-    optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    if c.stopnet and c.separate_stopnet:
+        optimizer_st = optim.Adam(
+            model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
+    else:
+        optimizer_st = None
 
     if c.loss_masking:
         criterion = L1LossMasked() if c.model == "Tacotron" else MSELossMasked()
     else:
         criterion = nn.L1Loss() if c.model == "Tacotron" else nn.MSELoss()
-    criterion_st = nn.BCEWithLogitsLoss()
+    criterion_st = nn.BCEWithLogitsLoss() if c.stopnet else None
 
     if args.restore_path:
         checkpoint = torch.load(args.restore_path)
@@ -420,23 +431,19 @@ def main(args):
             model_dict = set_init_dict(model_dict, checkpoint, c)
             model.load_state_dict(model_dict)
             del model_dict
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
         for group in optimizer.param_groups:
             group['lr'] = c.lr
         print(
             " > Model restored from step %d" % checkpoint['step'], flush=True)
         start_epoch = checkpoint['epoch']
         # best_loss = checkpoint['postnet_loss']
         args.restore_step = checkpoint['step']
     else:
         args.restore_step = 0
-        if use_cuda:
-            model = model.cuda()
-            criterion.cuda()
-            criterion_st.cuda()
+    if use_cuda:
+        model = model.cuda()
+        criterion.cuda()
+        if criterion_st: criterion_st.cuda();
 
     # DISTRUBUTED
     if num_gpus > 1:
@@ -457,9 +464,10 @@ def main(args):
     best_loss = float('inf')
 
     for epoch in range(0, c.epochs):
-        train_loss, current_step = train(model, criterion, criterion_st,
-                                         optimizer, optimizer_st, scheduler,
-                                         ap, epoch)
+        # train_loss, current_step = train(model, criterion, criterion_st,
+        #                                  optimizer, optimizer_st, scheduler,
+        #                                  ap, epoch)
+        train_loss, current_step = 0, 0
         val_loss = evaluate(model, criterion, criterion_st, ap, current_step, epoch)
         print(
             " | > Training Loss: {:.5f} Validation Loss: {:.5f}".format(