mirror of https://github.com/coqui-ai/TTS.git
linter fix
parent
609d8efa69
commit
d45d963dc1
15
train.py
15
train.py
|
@ -305,9 +305,9 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||
model.eval()
|
||||
epoch_time = 0
|
||||
eval_values_dict = {'avg_postnet_loss' : 0,
|
||||
'avg_decoder_loss' : 0,
|
||||
'avg_stop_loss' : 0,
|
||||
eval_values_dict = {'avg_postnet_loss': 0,
|
||||
'avg_decoder_loss': 0,
|
||||
'avg_stop_loss': 0,
|
||||
'avg_align_score': 0}
|
||||
keep_avg = KeepAverage()
|
||||
keep_avg.add_values(eval_values_dict)
|
||||
|
@ -401,14 +401,15 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
|
|||
if c.stopnet:
|
||||
stop_loss = reduce_tensor(stop_loss.data, num_gpus)
|
||||
|
||||
keep_avg.update_values({'avg_postnet_loss' : float(postnet_loss.item()),
|
||||
'avg_decoder_loss' : float(decoder_loss.item()),
|
||||
'avg_stop_loss' : float(stop_loss.item())})
|
||||
keep_avg.update_values({'avg_postnet_loss': float(postnet_loss.item()),
|
||||
'avg_decoder_loss': float(decoder_loss.item()),
|
||||
'avg_stop_loss': float(stop_loss.item())})
|
||||
|
||||
if num_iter % c.print_step == 0:
|
||||
print(
|
||||
" | > TotalLoss: {:.5f} PostnetLoss: {:.5f} - {:.5f} DecoderLoss:{:.5f} - {:.5f} "
|
||||
"StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(loss.item(),
|
||||
"StopLoss: {:.5f} - {:.5f} AlignScore: {:.4f} : {:.4f}".format(
|
||||
loss.item(),
|
||||
postnet_loss.item(), keep_avg['avg_postnet_loss'],
|
||||
decoder_loss.item(), keep_avg['avg_decoder_loss'],
|
||||
stop_loss.item(), keep_avg['avg_stop_loss'],
|
||||
|
|
|
@ -31,7 +31,8 @@ def load_config(config_path):
|
|||
def get_git_branch():
|
||||
try:
|
||||
out = subprocess.check_output(["git", "branch"]).decode("utf8")
|
||||
current = next(line for line in out.split("\n") if line.startswith("*"))
|
||||
current = next(line for line in out.split(
|
||||
"\n") if line.startswith("*"))
|
||||
current.replace("* ", "")
|
||||
except subprocess.CalledProcessError:
|
||||
current = "inside_docker"
|
||||
|
@ -333,7 +334,8 @@ class KeepAverage():
|
|||
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
|
||||
self.iters[name] += 1
|
||||
else:
|
||||
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
|
||||
self.avg_values[name] = self.avg_values[name] * \
|
||||
self.iters[name] + value
|
||||
self.iters[name] += 1
|
||||
self.avg_values[name] /= self.iters[name]
|
||||
|
||||
|
@ -344,4 +346,3 @@ class KeepAverage():
|
|||
def update_values(self, value_dict):
|
||||
for key, value in value_dict.items():
|
||||
self.update_value(key, value)
|
||||
|
||||
|
|
|
@ -1,6 +1,3 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def alignment_diagonal_score(alignments):
|
||||
"""
|
||||
|
@ -12,8 +9,3 @@ def alignment_diagonal_score(alignments):
|
|||
alignments : batch x decoder_steps x encoder_steps
|
||||
"""
|
||||
return alignments.max(dim=1)[0].mean(dim=1).mean(dim=0)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue