mirror of https://github.com/coqui-ai/TTS.git
update trainer.py for better logging handling, restoring models and
rename init_ functions with get_pull/602/head
parent
fb9289d365
commit
8dfd4c91ff
|
@ -10,7 +10,11 @@ def main():
|
|||
# try:
|
||||
args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(
|
||||
sys.argv)
|
||||
trainer = TrainerTTS(args, config, c_logger, tb_logger, output_path=OUT_PATH)
|
||||
trainer = TrainerTTS(args,
|
||||
config,
|
||||
c_logger,
|
||||
tb_logger,
|
||||
output_path=OUT_PATH)
|
||||
trainer.fit()
|
||||
# except KeyboardInterrupt:
|
||||
# remove_experiment_folder(OUT_PATH)
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from logging import StreamHandler
|
||||
from random import randrange
|
||||
import logging
|
||||
import importlib
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -16,19 +17,19 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
|
|||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.tts.datasets import load_meta_data, TTSDataset
|
||||
from TTS.tts.datasets import TTSDataset, load_meta_data
|
||||
from TTS.tts.layers import setup_loss
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.arguments import init_training
|
||||
from TTS.tts.utils.visual import plot_spectrogram, plot_alignment
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict, find_module
|
||||
from TTS.utils.training import setup_torch_training_env, check_update
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, find_module, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.training import check_update, setup_torch_training_env
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -140,9 +141,8 @@ class TrainerTTS:
|
|||
self.config, args.restore_path, self.model, self.optimizer,
|
||||
self.scaler)
|
||||
|
||||
if self.use_cuda:
|
||||
self.model.cuda()
|
||||
self.criterion.cuda()
|
||||
# setup scheduler
|
||||
self.scheduler = self.get_scheduler(self.config, self.optimizer)
|
||||
|
||||
# DISTRUBUTED
|
||||
if self.num_gpus > 1:
|
||||
|
@ -150,8 +150,7 @@ class TrainerTTS:
|
|||
|
||||
# count model size
|
||||
num_params = count_parameters(self.model)
|
||||
logging.info("\n > Model has {} parameters".format(num_params),
|
||||
flush=True)
|
||||
logging.info("\n > Model has {} parameters".format(num_params))
|
||||
|
||||
@staticmethod
|
||||
def get_model(num_chars: int, num_speakers: int, config: Coqpit,
|
||||
|
@ -241,7 +240,6 @@ class TrainerTTS:
|
|||
try:
|
||||
logging.info(" > Restoring Model...")
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
# optimizer restore
|
||||
logging.info(" > Restoring Optimizer...")
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
if "scaler" in checkpoint and config.mixed_precision:
|
||||
|
|
Loading…
Reference in New Issue