mirror of https://github.com/coqui-ai/TTS.git
bug fixes on train.py
parent
96b48c003a
commit
17b65d5cde
14
train.py
14
train.py
|
@ -43,7 +43,9 @@ def setup_loader(is_val=False):
|
|||
preprocessor=preprocessor,
|
||||
ap=ap,
|
||||
batch_group_size=0 if is_val else 8 * c.batch_size,
|
||||
min_seq_len=0 if is_val else c.min_seq_len)
|
||||
min_seq_len=0 if is_val else c.min_seq_len,
|
||||
max_seq_len=float("inf") if is_val else c.max_seq_len
|
||||
cached=False if c.dataset ~= "tts_cache" else True)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=c.eval_batch_size if is_val else c.batch_size,
|
||||
|
@ -164,8 +166,8 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
|||
grad_norm, grad_norm_st, avg_text_length, avg_spec_length, step_time, current_lr),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
avg_step_time += step_time
|
||||
|
||||
|
@ -198,7 +200,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st,
|
|||
# Sample audio
|
||||
tb_logger.tb_train_audios(current_step,
|
||||
{'TrainAudio': ap.inv_spectrogram(const_spec.T)},
|
||||
c.sample_rate)
|
||||
c.audio["sample_rate"])
|
||||
|
||||
avg_linear_loss /= (num_iter + 1)
|
||||
avg_mel_loss /= (num_iter + 1)
|
||||
|
@ -295,8 +297,8 @@ def evaluate(model, criterion, criterion_st, ap, current_step):
|
|||
stop_loss.item()),
|
||||
flush=True)
|
||||
|
||||
avg_linear_loss += linear_loss.item()
|
||||
avg_mel_loss += mel_loss.item()
|
||||
avg_linear_loss += float(linear_loss.item())
|
||||
avg_mel_loss += float(mel_loss.item())
|
||||
avg_stop_loss += stop_loss.item()
|
||||
|
||||
# Diagnostic visualizations
|
||||
|
|
Loading…
Reference in New Issue