mirror of https://github.com/coqui-ai/TTS.git
linter fix
parent
566c2a4678
commit
78464f1ead
8
train.py
8
train.py
|
@ -368,13 +368,13 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
stop_tokens, stop_targets) if c.stopnet else torch.zeros(1)
|
||||
if c.loss_masking:
|
||||
decoder_loss = criterion(decoder_output, mel_input,
|
||||
mel_lengths)
|
||||
mel_lengths)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
postnet_loss = criterion(postnet_output, linear_input,
|
||||
mel_lengths)
|
||||
mel_lengths)
|
||||
else:
|
||||
postnet_loss = criterion(postnet_output, mel_input,
|
||||
mel_lengths)
|
||||
mel_lengths)
|
||||
else:
|
||||
decoder_loss = criterion(decoder_output, mel_input)
|
||||
if c.model in ["Tacotron", "TacotronGST"]:
|
||||
|
@ -449,7 +449,7 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
else:
|
||||
eval_audio = ap.inv_mel_spectrogram(const_spec.T)
|
||||
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
|
||||
c.audio["sample_rate"])
|
||||
c.audio["sample_rate"])
|
||||
|
||||
# Plot Validation Stats
|
||||
epoch_stats = {
|
||||
|
|
Loading…
Reference in New Issue