sync torch calls before logging training results

pull/10/head
erogol 2020-12-07 11:30:19 +01:00
parent 7505c0ba27
commit 482e725752
1 changed files with 4 additions and 1 deletions

View File

@ -186,7 +186,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_ids)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
@ -273,6 +273,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
model_loss=loss_dict['loss'])
# wait all kernels to be completed
torch.cuda.synchronize()
# Diagnostic visualizations
# direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1]