mirror of https://github.com/coqui-ai/TTS.git
use decorator for torch.no_grad
parent abf8ea4633
commit 2cec58320b
train.py  224
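For context: torch.no_grad can be used either as a context manager or as a decorator, and this commit switches evaluate() to the decorator form, which removes one indentation level from the function body. A minimal sketch of the two equivalent forms (model and batch are placeholder names, not taken from train.py):

    import torch

    # Context-manager form: gradient tracking is disabled only inside the block.
    def evaluate_with_block(model, batch):
        with torch.no_grad():
            return model(batch)

    # Decorator form: the whole function body runs with gradient tracking disabled.
    @torch.no_grad()
    def evaluate_decorated(model, batch):
        return model(batch)

Both forms disable autograd bookkeeping during validation, which saves memory; neither changes the values the model computes.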
@@ -327,6 +327,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
     return keep_avg['avg_postnet_loss'], global_step
 
 
+@torch.no_grad()
 def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     data_loader = setup_loader(ap, model.decoder.r, is_val=True)
     if c.use_speaker_embedding:
@@ -346,125 +347,124 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
     keep_avg.add_values(eval_values_dict)
     print("\n > Validation")
 
-    with torch.no_grad():
-        if data_loader is not None:
-            for num_iter, data in enumerate(data_loader):
-                start_time = time.time()
-
-                # format data
-                text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
-                assert mel_input.shape[1] % model.decoder.r == 0
-
-                # forward pass model
-                if c.bidirectional_decoder:
-                    decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
-                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
-                else:
-                    decoder_output, postnet_output, alignments, stop_tokens = model(
-                        text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
-
-                # loss computation
-                stop_loss = criterion_st(
-                    stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
-                if c.loss_masking:
-                    decoder_loss = criterion(decoder_output, mel_input,
-                                             mel_lengths)
-                    if c.model in ["Tacotron", "TacotronGST"]:
-                        postnet_loss = criterion(postnet_output, linear_input,
-                                                 mel_lengths)
-                    else:
-                        postnet_loss = criterion(postnet_output, mel_input,
-                                                 mel_lengths)
-                else:
-                    decoder_loss = criterion(decoder_output, mel_input)
-                    if c.model in ["Tacotron", "TacotronGST"]:
-                        postnet_loss = criterion(postnet_output, linear_input)
-                    else:
-                        postnet_loss = criterion(postnet_output, mel_input)
-                loss = decoder_loss + postnet_loss + stop_loss
-
-                # backward decoder loss
-                if c.bidirectional_decoder:
-                    if c.loss_masking:
-                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
-                    else:
-                        decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
-                    decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
-                    loss += decoder_backward_loss + decoder_c_loss
-                    keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
-
-                step_time = time.time() - start_time
-                epoch_time += step_time
-
-                # compute alignment score
-                align_score = alignment_diagonal_score(alignments)
-                keep_avg.update_value('avg_align_score', align_score)
-
-                # aggregate losses from processes
-                if num_gpus > 1:
-                    postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
-                    decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
-                    if c.stopnet:
-                        stop_loss = reduce_tensor(stop_loss.data, num_gpus)
-
-                keep_avg.update_values({
-                    'avg_postnet_loss':
-                    float(postnet_loss.item()),
-                    'avg_decoder_loss':
-                    float(decoder_loss.item()),
-                    'avg_stop_loss':
-                    float(stop_loss.item()),
-                })
-
-                if num_iter % c.print_step == 0:
-                    print(
-                        " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
-                        "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
-                        .format(loss.item(), postnet_loss.item(),
-                                keep_avg['avg_postnet_loss'],
-                                decoder_loss.item(),
-                                keep_avg['avg_decoder_loss'], stop_loss.item(),
-                                keep_avg['avg_stop_loss'], align_score,
-                                keep_avg['avg_align_score']),
-                        flush=True)
-
-            if args.rank == 0:
-                # Diagnostic visualizations
-                idx = np.random.randint(mel_input.shape[0])
-                const_spec = postnet_output[idx].data.cpu().numpy()
-                gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
-                    "Tacotron", "TacotronGST"
-                ] else mel_input[idx].data.cpu().numpy()
-                align_img = alignments[idx].data.cpu().numpy()
-
-                eval_figures = {
-                    "prediction": plot_spectrogram(const_spec, ap),
-                    "ground_truth": plot_spectrogram(gt_spec, ap),
-                    "alignment": plot_alignment(align_img)
-                }
-
-                # Sample audio
-                if c.model in ["Tacotron", "TacotronGST"]:
-                    eval_audio = ap.inv_spectrogram(const_spec.T)
-                else:
-                    eval_audio = ap.inv_mel_spectrogram(const_spec.T)
-                tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
-                                         c.audio["sample_rate"])
-
-                # Plot Validation Stats
-                epoch_stats = {
-                    "loss_postnet": keep_avg['avg_postnet_loss'],
-                    "loss_decoder": keep_avg['avg_decoder_loss'],
-                    "stop_loss": keep_avg['avg_stop_loss'],
-                    "alignment_score": keep_avg['avg_align_score']
-                }
-
-                if c.bidirectional_decoder:
-                    epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
-                    align_b_img = alignments_backward[idx].data.cpu().numpy()
-                    eval_figures['alignment_backward'] = plot_alignment(align_b_img)
-                tb_logger.tb_eval_stats(global_step, epoch_stats)
-                tb_logger.tb_eval_figures(global_step, eval_figures)
+    if data_loader is not None:
+        for num_iter, data in enumerate(data_loader):
+            start_time = time.time()
+
+            # format data
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
+            assert mel_input.shape[1] % model.decoder.r == 0
+
+            # forward pass model
+            if c.bidirectional_decoder:
+                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+            else:
+                decoder_output, postnet_output, alignments, stop_tokens = model(
+                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids)
+
+            # loss computation
+            stop_loss = criterion_st(
+                stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
+            if c.loss_masking:
+                decoder_loss = criterion(decoder_output, mel_input,
+                                         mel_lengths)
+                if c.model in ["Tacotron", "TacotronGST"]:
+                    postnet_loss = criterion(postnet_output, linear_input,
+                                             mel_lengths)
+                else:
+                    postnet_loss = criterion(postnet_output, mel_input,
+                                             mel_lengths)
+            else:
+                decoder_loss = criterion(decoder_output, mel_input)
+                if c.model in ["Tacotron", "TacotronGST"]:
+                    postnet_loss = criterion(postnet_output, linear_input)
+                else:
+                    postnet_loss = criterion(postnet_output, mel_input)
+            loss = decoder_loss + postnet_loss + stop_loss
+
+            # backward decoder loss
+            if c.bidirectional_decoder:
+                if c.loss_masking:
+                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input, mel_lengths)
+                else:
+                    decoder_backward_loss = criterion(torch.flip(decoder_backward_output, dims=(1, )), mel_input)
+                decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_backward_output, dims=(1, )), decoder_output)
+                loss += decoder_backward_loss + decoder_c_loss
+                keep_avg.update_values({'avg_decoder_b_loss': decoder_backward_loss.item(), 'avg_decoder_c_loss': decoder_c_loss.item()})
+
+            step_time = time.time() - start_time
+            epoch_time += step_time
+
+            # compute alignment score
+            align_score = alignment_diagonal_score(alignments)
+            keep_avg.update_value('avg_align_score', align_score)
+
+            # aggregate losses from processes
+            if num_gpus > 1:
+                postnet_loss = reduce_tensor(postnet_loss.data, num_gpus)
+                decoder_loss = reduce_tensor(decoder_loss.data, num_gpus)
+                if c.stopnet:
+                    stop_loss = reduce_tensor(stop_loss.data, num_gpus)
+
+            keep_avg.update_values({
+                'avg_postnet_loss':
+                float(postnet_loss.item()),
+                'avg_decoder_loss':
+                float(decoder_loss.item()),
+                'avg_stop_loss':
+                float(stop_loss.item()),
+            })
+
+            if num_iter % c.print_step == 0:
+                print(
+                    " | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
+                    "StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}"
+                    .format(loss.item(), postnet_loss.item(),
+                            keep_avg['avg_postnet_loss'],
+                            decoder_loss.item(),
+                            keep_avg['avg_decoder_loss'], stop_loss.item(),
+                            keep_avg['avg_stop_loss'], align_score,
+                            keep_avg['avg_align_score']),
+                    flush=True)
+
+        if args.rank == 0:
+            # Diagnostic visualizations
+            idx = np.random.randint(mel_input.shape[0])
+            const_spec = postnet_output[idx].data.cpu().numpy()
+            gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
+                "Tacotron", "TacotronGST"
+            ] else mel_input[idx].data.cpu().numpy()
+            align_img = alignments[idx].data.cpu().numpy()
+
+            eval_figures = {
+                "prediction": plot_spectrogram(const_spec, ap),
+                "ground_truth": plot_spectrogram(gt_spec, ap),
+                "alignment": plot_alignment(align_img)
+            }
+
+            # Sample audio
+            if c.model in ["Tacotron", "TacotronGST"]:
+                eval_audio = ap.inv_spectrogram(const_spec.T)
+            else:
+                eval_audio = ap.inv_mel_spectrogram(const_spec.T)
+            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
+                                     c.audio["sample_rate"])
+
+            # Plot Validation Stats
+            epoch_stats = {
+                "loss_postnet": keep_avg['avg_postnet_loss'],
+                "loss_decoder": keep_avg['avg_decoder_loss'],
+                "stop_loss": keep_avg['avg_stop_loss'],
+                "alignment_score": keep_avg['avg_align_score']
+            }
+
+            if c.bidirectional_decoder:
+                epoch_stats['loss_decoder_backward'] = keep_avg['avg_decoder_b_loss']
+                align_b_img = alignments_backward[idx].data.cpu().numpy()
+                eval_figures['alignment_backward'] = plot_alignment(align_b_img)
+            tb_logger.tb_eval_stats(global_step, epoch_stats)
+            tb_logger.tb_eval_figures(global_step, eval_figures)
 
     if args.rank == 0 and epoch > c.test_delay_epochs:
         if c.test_sentences_file is None:
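As a quick sanity check (a hypothetical snippet, not part of this commit), the decorated function behaves exactly like the context-manager version: autograd is disabled for the whole call, so tensors produced inside do not track gradients:

    import torch

    @torch.no_grad()
    def probe():
        x = torch.ones(2, requires_grad=True)
        y = x * 2
        # Inside the decorated call, torch.is_grad_enabled() is False and
        # y.requires_grad is False, matching `with torch.no_grad():`.
        return torch.is_grad_enabled(), y.requires_grad

    print(probe())  # (False, False)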