mirror of https://github.com/coqui-ai/TTS.git
prompt data loade time per iteartion
parent
1827f77752
commit
713b3df792
9
train.py
9
train.py
|
@ -179,6 +179,9 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
optimizer_st.step()
|
||||
else:
|
||||
grad_norm_st = 0
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
if current_step % c.print_step == 0:
|
||||
print(
|
||||
|
@ -242,9 +245,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_train_audios(current_step,
|
||||
{'TrainAudio': train_audio},
|
||||
c.audio["sample_rate"])
|
||||
|
||||
step_time = end_time - start_time
|
||||
epoch_time += step_time
|
||||
end_time = time.time()
|
||||
|
||||
avg_postnet_loss /= (num_iter + 1)
|
||||
avg_decoder_loss /= (num_iter + 1)
|
||||
|
@ -274,8 +275,6 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_train_epoch_stats(current_step, epoch_stats)
|
||||
if c.tb_model_param_stats:
|
||||
tb_logger.tb_model_weights(model, current_step)
|
||||
|
||||
end_time = time.time()
|
||||
return avg_postnet_loss, current_step
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue