diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index d36704fc..067a166f 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -14,9 +14,8 @@ from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import (KeepAverage, count_parameters, remove_experiment_folder, set_init_dict) - -from TTS.utils.radam import RAdam - +from TTS.utils.io import copy_model_files, load_config +from TTS.utils.tensorboard_logger import TensorboardLogger from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data @@ -161,8 +160,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad) optimizer_G.step() - if scheduler_G is not None: - scheduler_G.step() loss_dict = dict() for key, value in loss_G_dict.items(): @@ -221,8 +218,6 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad) optimizer_D.step() - if scheduler_D is not None: - scheduler_D.step() for key, value in loss_D_dict.items(): if isinstance(value, (int, float)): @@ -293,7 +288,12 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, {'train/audio': sample_voice}, c.audio["sample_rate"]) end_time = time.time() - torch.cuda.empty_cache() + + if scheduler_G is not None: + scheduler_G.step() + + if scheduler_D is not None: + scheduler_D.step() # print epoch stats c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)