mirror of https://github.com/coqui-ai/TTS.git
sync torch calls before logging training results
parent
7505c0ba27
commit
482e725752
|
@ -186,7 +186,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
|||
# forward pass model
|
||||
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
||||
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
|
||||
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
|
||||
|
@ -273,6 +273,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
|
|||
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
|
||||
model_loss=loss_dict['loss'])
|
||||
|
||||
# wait all kernels to be completed
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Diagnostic visualizations
|
||||
# direct pass on model for spec predictions
|
||||
target_speaker = None if speaker_c is None else speaker_c[:1]
|
||||
|
|
Loading…
Reference in New Issue