move scheduler updates to the end of the epoch

pull/422/head
Eren Gölge 2021-04-08 11:11:55 +02:00
parent 2a872c98aa
commit 3fb78c004a
1 changed files with 8 additions and 8 deletions

View File

@ -14,9 +14,8 @@ from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.radam import RAdam
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
@ -161,8 +160,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
c.gen_clip_grad)
optimizer_G.step()
if scheduler_G is not None:
scheduler_G.step()
loss_dict = dict()
for key, value in loss_G_dict.items():
@ -221,8 +218,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
c.disc_clip_grad)
optimizer_D.step()
if scheduler_D is not None:
scheduler_D.step()
for key, value in loss_D_dict.items():
if isinstance(value, (int, float)):
@ -293,7 +288,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
{'train/audio': sample_voice},
c.audio["sample_rate"])
end_time = time.time()
torch.cuda.empty_cache()
if scheduler_G is not None:
scheduler_G.step()
if scheduler_D is not None:
scheduler_D.step()
# print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)