update trainer.py: improve logging handling, model restoring, and

rename init_* functions to get_*
pull/602/head
Eren Gölge 2021-05-27 10:24:26 +02:00
parent fb9289d365
commit 8dfd4c91ff
2 changed files with 15 additions and 13 deletions
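
The diff below routes trainer.py's console output through Python's standard logging module (note the logging and StreamHandler imports and the logging.info calls in the hunks). For reference, a minimal sketch of how such a console logger can be wired up; the handler and formatter choices here are illustrative and not taken from this commit:

    import logging
    from logging import StreamHandler

    # route INFO-level messages to the console via a stream handler
    handler = StreamHandler()
    handler.setFormatter(logging.Formatter("%(message)s"))

    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.INFO)

    # messages emitted with logging.info(...) now reach the console,
    # without print-style arguments such as flush=True
    logging.info(" > Model has %d parameters", 1000000)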


@@ -10,7 +10,11 @@ def main():
     # try:
     args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
         sys.argv)
-    trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
+    trainer = TrainerTTS(args,
+                         config,
+                         c_logger,
+                         tb_logger,
+                         output_path=OUT_PATH)
     trainer.fit()
     # except KeyboardInterrupt:
     #     remove_experiment_folder(OUT_PATH)


@@ -1,12 +1,13 @@
 # -*- coding: utf-8 -*-
+import importlib
+import logging
 import os
 import sys
 import time
 import traceback
 from logging import StreamHandler
 from random import randrange
-import logging
-import importlib
 import numpy as np
 import torch
@@ -16,19 +17,19 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
-from TTS.tts.datasets import load_meta_data, TTSDataset
+from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
 from TTS.tts.utils.io import save_best_model, save_checkpoint
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.tts.utils.synthesis import synthesis
 from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.arguments import init_training
-from TTS.tts.utils.visual import plot_spectrogram, plot_alignment
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed, reduce_tensor
-from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module
-from TTS.utils.training import setup_torch_training_env, check_update
+from TTS.utils.generic_utils import KeepAverage, count_parameters, find_module, remove_experiment_folder, set_init_dict
+from TTS.utils.training import check_update, setup_torch_training_env
 @dataclass
@@ -140,9 +141,8 @@ class TrainerTTS:
             self.config, args.restore_path, self.model, self.optimizer,
             self.scaler)
         if self.use_cuda:
             self.model.cuda()
             self.criterion.cuda()
-        # setup scheduler
         self.scheduler = self.get_scheduler(self.config, self.optimizer)
         # DISTRUBUTED
         if self.num_gpus > 1:
@@ -150,8 +150,7 @@
         # count model size
         num_params = count_parameters(self.model)
-        logging.info("\n > Model has {} parameters".format(num_params),
-                     flush=True)
+        logging.info("\n > Model has {} parameters".format(num_params))
     @staticmethod
     def get_model(num_chars: int, num_speakers: int, config: Coqpit,
@@ -241,7 +240,6 @@ class TrainerTTS:
         try:
             logging.info(" > Restoring Model...")
             model.load_state_dict(checkpoint["model"])
-            # optimizer restore
             logging.info(" > Restoring Optimizer...")
             optimizer.load_state_dict(checkpoint["optimizer"])
             if "scaler" in checkpoint and config.mixed_precision: