mirror of https://github.com/coqui-ai/TTS.git
move scheduler updates to the end of the epoch
parent
2a872c98aa
commit
3fb78c004a
|
@ -14,9 +14,8 @@ from TTS.utils.arguments import parse_arguments, process_args
|
|||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
|
||||
from TTS.utils.radam import RAdam
|
||||
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
|
@ -161,8 +160,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
|
||||
c.gen_clip_grad)
|
||||
optimizer_G.step()
|
||||
if scheduler_G is not None:
|
||||
scheduler_G.step()
|
||||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
|
@ -221,8 +218,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
|
||||
c.disc_clip_grad)
|
||||
optimizer_D.step()
|
||||
if scheduler_D is not None:
|
||||
scheduler_D.step()
|
||||
|
||||
for key, value in loss_D_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
|
@ -293,7 +288,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
{'train/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if scheduler_G is not None:
|
||||
scheduler_G.step()
|
||||
|
||||
if scheduler_D is not None:
|
||||
scheduler_D.step()
|
||||
|
||||
# print epoch stats
|
||||
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
|
||||
|
|
Loading…
Reference in New Issue