From 8dfd4c91ff300f83d6c7aed6b1fa2fc325859069 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Thu, 27 May 2021 10:24:26 +0200 Subject: [PATCH] update trainer.py for better logging handling, restoring models and rename init_ functions with get_ --- TTS/bin/train_tts.py | 6 +++++- TTS/trainer.py | 22 ++++++++++------------ 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/TTS/bin/train_tts.py b/TTS/bin/train_tts.py index 5058d341..7cc8a25f 100644 --- a/TTS/bin/train_tts.py +++ b/TTS/bin/train_tts.py @@ -10,7 +10,11 @@ def main(): # try: args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training( sys.argv) - trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH) + trainer = TrainerTTS(args, + config, + c_logger, + tb_logger, + output_path=OUT_PATH) trainer.fit() # except KeyboardInterrupt: # remove_experiment_folder(OUT_PATH) diff --git a/TTS/trainer.py b/TTS/trainer.py index 3beb281f..6087f1bc 100644 --- a/TTS/trainer.py +++ b/TTS/trainer.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- +import importlib +import logging import os import sys import time import traceback +from logging import StreamHandler from random import randrange -import logging -import importlib import numpy as np import torch @@ -16,19 +17,19 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader from torch.utils.data.distributed import DistributedSampler -from TTS.tts.datasets import load_meta_data, TTSDataset +from TTS.tts.datasets import TTSDataset, load_meta_data from TTS.tts.layers import setup_loss from TTS.tts.models import setup_model from TTS.tts.utils.io import save_best_model, save_checkpoint from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols +from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.arguments import init_training -from TTS.tts.utils.visual import plot_spectrogram, plot_alignment from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module -from TTS.utils.training import setup_torch_training_env, check_update +from TTS.utils.generic_utils import KeepAverage, count_parameters, find_module, remove_experiment_folder, set_init_dict +from TTS.utils.training import check_update, setup_torch_training_env @dataclass @@ -140,9 +141,8 @@ class TrainerTTS: self.config, args.restore_path, self.model, self.optimizer, self.scaler) - if self.use_cuda: - self.model.cuda() - self.criterion.cuda() + # setup scheduler + self.scheduler = self.get_scheduler(self.config, self.optimizer) # DISTRUBUTED if self.num_gpus > 1: @@ -150,8 +150,7 @@ class TrainerTTS: # count model size num_params = count_parameters(self.model) - logging.info("\n > Model has {} parameters".format(num_params), - flush=True) + logging.info("\n > Model has {} parameters".format(num_params)) @staticmethod def get_model(num_chars: int, num_speakers: int, config: Coqpit, @@ -241,7 +240,6 @@ class TrainerTTS: try: logging.info(" > Restoring Model...") model.load_state_dict(checkpoint["model"]) - # optimizer restore logging.info(" > Restoring Optimizer...") optimizer.load_state_dict(checkpoint["optimizer"]) if "scaler" in checkpoint and config.mixed_precision: