mirror of https://github.com/coqui-ai/TTS.git
commit
b47d9c6e36
|
@ -8,17 +8,17 @@ import traceback
|
|||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from trainer.io import copy_model_files, save_best_model, save_checkpoint
|
||||
from trainer.torch import NoamLR
|
||||
from trainer.trainer_utils import get_optimizer
|
||||
|
||||
from TTS.encoder.dataset import EncoderDataset
|
||||
from TTS.encoder.utils.generic_utils import save_best_model, save_checkpoint, setup_encoder_model
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.training import init_training
|
||||
from TTS.encoder.utils.visual import plot_embeddings
|
||||
from TTS.tts.datasets import load_tts_samples
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.samplers import PerfectBatchSampler
|
||||
from TTS.utils.training import check_update
|
||||
|
||||
|
@ -222,7 +222,9 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
|
||||
if global_step % c.save_step == 0:
|
||||
# save model
|
||||
save_checkpoint(model, optimizer, criterion, loss.item(), OUT_PATH, global_step, epoch)
|
||||
save_checkpoint(
|
||||
c, model, optimizer, None, global_step, epoch, OUT_PATH, criterion=criterion.state_dict()
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
@ -245,7 +247,18 @@ def train(model, optimizer, scheduler, criterion, data_loader, eval_data_loader,
|
|||
flush=True,
|
||||
)
|
||||
# save the best checkpoint
|
||||
best_loss = save_best_model(model, optimizer, criterion, eval_loss, best_loss, OUT_PATH, global_step, epoch)
|
||||
best_loss = save_best_model(
|
||||
eval_loss,
|
||||
best_loss,
|
||||
c,
|
||||
model,
|
||||
optimizer,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
criterion=criterion.state_dict(),
|
||||
)
|
||||
model.train()
|
||||
|
||||
return best_loss, global_step
|
||||
|
@ -276,7 +289,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
if c.loss == "softmaxproto" and c.model != "speaker_encoder":
|
||||
c.map_classid_to_classname = map_classid_to_classname
|
||||
copy_model_files(c, OUT_PATH)
|
||||
copy_model_files(c, OUT_PATH, new_fields={})
|
||||
|
||||
if args.restore_path:
|
||||
criterion, args.restore_step = model.load_checkpoint(
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
import datetime
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
||||
from TTS.encoder.models.lstm import LSTMSpeakerEncoder
|
||||
from TTS.encoder.models.resnet import ResNetSpeakerEncoder
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
class AugmentWAV(object):
|
||||
|
@ -118,11 +115,6 @@ class AugmentWAV(object):
|
|||
return self.additive_noise(noise_type, audio)
|
||||
|
||||
|
||||
def to_camel(text):
|
||||
text = text.capitalize()
|
||||
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
|
||||
|
||||
|
||||
def setup_encoder_model(config: "Coqpit"):
|
||||
if config.model_params["model_name"].lower() == "lstm":
|
||||
model = LSTMSpeakerEncoder(
|
||||
|
@ -142,41 +134,3 @@ def setup_encoder_model(config: "Coqpit"):
|
|||
audio_config=config.audio,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, criterion, model_loss, out_path, current_step, epoch):
|
||||
checkpoint_path = "checkpoint_{}.pth".format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"criterion": criterion.state_dict(),
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
save_fsspec(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, criterion, model_loss, best_loss, out_path, current_step, epoch):
|
||||
if model_loss < best_loss:
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"criterion": criterion.state_dict(),
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
best_loss = model_loss
|
||||
bestmodel_path = "best_model.pth"
|
||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
|
||||
save_fsspec(state, bestmodel_path)
|
||||
return best_loss
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
import datetime
|
||||
import os
|
||||
|
||||
from TTS.utils.io import save_fsspec
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, out_path, current_step):
|
||||
checkpoint_path = "checkpoint_{}.pth".format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"step": current_step,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
save_fsspec(state, checkpoint_path)
|
||||
|
||||
|
||||
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
|
||||
if model_loss < best_loss:
|
||||
new_state_dict = model.state_dict()
|
||||
state = {
|
||||
"model": new_state_dict,
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"step": current_step,
|
||||
"loss": model_loss,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
best_loss = model_loss
|
||||
bestmodel_path = "best_model.pth"
|
||||
bestmodel_path = os.path.join(out_path, bestmodel_path)
|
||||
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
|
||||
save_fsspec(state, bestmodel_path)
|
||||
return best_loss
|
|
@ -3,13 +3,13 @@ from dataclasses import dataclass, field
|
|||
|
||||
from coqpit import Coqpit
|
||||
from trainer import TrainerArgs, get_last_checkpoint
|
||||
from trainer.io import copy_model_files
|
||||
from trainer.logging import logger_factory
|
||||
from trainer.logging.console_logger import ConsoleLogger
|
||||
|
||||
from TTS.config import load_config, register_config
|
||||
from TTS.tts.utils.text.characters import parse_symbols
|
||||
from TTS.utils.generic_utils import get_experiment_folder_path, get_git_branch
|
||||
from TTS.utils.io import copy_model_files
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
146
TTS/utils/io.py
146
TTS/utils/io.py
|
@ -1,13 +1,9 @@
|
|||
import datetime
|
||||
import json
|
||||
import os
|
||||
import pickle as pickle_tts
|
||||
import shutil
|
||||
from typing import Any, Callable, Dict, Union
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.utils.generic_utils import get_user_data_dir
|
||||
|
||||
|
@ -28,34 +24,6 @@ class AttrDict(dict):
|
|||
self.__dict__ = self
|
||||
|
||||
|
||||
def copy_model_files(config: Coqpit, out_path, new_fields=None):
|
||||
"""Copy config.json and other model files to training folder and add
|
||||
new fields.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Coqpit config defining the training run.
|
||||
out_path (str): output path to copy the file.
|
||||
new_fields (dict): new fileds to be added or edited
|
||||
in the config file.
|
||||
"""
|
||||
copy_config_path = os.path.join(out_path, "config.json")
|
||||
# add extra information fields
|
||||
if new_fields:
|
||||
config.update(new_fields, allow_new=True)
|
||||
# TODO: Revert to config.save_json() once Coqpit supports arbitrary paths.
|
||||
with fsspec.open(copy_config_path, "w", encoding="utf8") as f:
|
||||
json.dump(config.to_dict(), f, indent=4)
|
||||
|
||||
# copy model stats file if available
|
||||
if config.audio.stats_path is not None:
|
||||
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
|
||||
filesystem = fsspec.get_mapper(copy_stats_path).fs
|
||||
if not filesystem.exists(copy_stats_path):
|
||||
with fsspec.open(config.audio.stats_path, "rb") as source_file:
|
||||
with fsspec.open(copy_stats_path, "wb") as target_file:
|
||||
shutil.copyfileobj(source_file, target_file)
|
||||
|
||||
|
||||
def load_fsspec(
|
||||
path: str,
|
||||
map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
|
||||
|
@ -100,117 +68,3 @@ def load_checkpoint(
|
|||
if eval:
|
||||
model.eval()
|
||||
return model, state
|
||||
|
||||
|
||||
def save_fsspec(state: Any, path: str, **kwargs):
|
||||
"""Like torch.save but can save to other locations (e.g. s3:// , gs://).
|
||||
|
||||
Args:
|
||||
state: State object to save
|
||||
path: Any path or url supported by fsspec.
|
||||
**kwargs: Keyword arguments forwarded to torch.save.
|
||||
"""
|
||||
with fsspec.open(path, "wb") as f:
|
||||
torch.save(state, f, **kwargs)
|
||||
|
||||
|
||||
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
|
||||
if hasattr(model, "module"):
|
||||
model_state = model.module.state_dict()
|
||||
else:
|
||||
model_state = model.state_dict()
|
||||
if isinstance(optimizer, list):
|
||||
optimizer_state = [optim.state_dict() for optim in optimizer]
|
||||
elif optimizer.__class__.__name__ == "CapacitronOptimizer":
|
||||
optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
|
||||
else:
|
||||
optimizer_state = optimizer.state_dict() if optimizer is not None else None
|
||||
|
||||
if isinstance(scaler, list):
|
||||
scaler_state = [s.state_dict() for s in scaler]
|
||||
else:
|
||||
scaler_state = scaler.state_dict() if scaler is not None else None
|
||||
|
||||
if isinstance(config, Coqpit):
|
||||
config = config.to_dict()
|
||||
|
||||
state = {
|
||||
"config": config,
|
||||
"model": model_state,
|
||||
"optimizer": optimizer_state,
|
||||
"scaler": scaler_state,
|
||||
"step": current_step,
|
||||
"epoch": epoch,
|
||||
"date": datetime.date.today().strftime("%B %d, %Y"),
|
||||
}
|
||||
state.update(kwargs)
|
||||
save_fsspec(state, output_path)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
config,
|
||||
model,
|
||||
optimizer,
|
||||
scaler,
|
||||
current_step,
|
||||
epoch,
|
||||
output_folder,
|
||||
**kwargs,
|
||||
):
|
||||
file_name = "checkpoint_{}.pth".format(current_step)
|
||||
checkpoint_path = os.path.join(output_folder, file_name)
|
||||
print("\n > CHECKPOINT : {}".format(checkpoint_path))
|
||||
save_model(
|
||||
config,
|
||||
model,
|
||||
optimizer,
|
||||
scaler,
|
||||
current_step,
|
||||
epoch,
|
||||
checkpoint_path,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def save_best_model(
|
||||
current_loss,
|
||||
best_loss,
|
||||
config,
|
||||
model,
|
||||
optimizer,
|
||||
scaler,
|
||||
current_step,
|
||||
epoch,
|
||||
out_path,
|
||||
keep_all_best=False,
|
||||
keep_after=10000,
|
||||
**kwargs,
|
||||
):
|
||||
if current_loss < best_loss:
|
||||
best_model_name = f"best_model_{current_step}.pth"
|
||||
checkpoint_path = os.path.join(out_path, best_model_name)
|
||||
print(" > BEST MODEL : {}".format(checkpoint_path))
|
||||
save_model(
|
||||
config,
|
||||
model,
|
||||
optimizer,
|
||||
scaler,
|
||||
current_step,
|
||||
epoch,
|
||||
checkpoint_path,
|
||||
model_loss=current_loss,
|
||||
**kwargs,
|
||||
)
|
||||
fs = fsspec.get_mapper(out_path).fs
|
||||
# only delete previous if current is saved successfully
|
||||
if not keep_all_best or (current_step < keep_after):
|
||||
model_names = fs.glob(os.path.join(out_path, "best_model*.pth"))
|
||||
for model_name in model_names:
|
||||
if os.path.basename(model_name) != best_model_name:
|
||||
fs.rm(model_name)
|
||||
# create a shortcut which always points to the currently best model
|
||||
shortcut_name = "best_model.pth"
|
||||
shortcut_path = os.path.join(out_path, shortcut_name)
|
||||
fs.copy(checkpoint_path, shortcut_path)
|
||||
best_loss = current_loss
|
||||
return best_loss
|
||||
|
|
|
@ -3,11 +3,11 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.managers import EmbeddingManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -31,7 +31,7 @@ class EmbeddingManagerTest(unittest.TestCase):
|
|||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
manager = EmbeddingManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
|
||||
|
|
|
@ -3,11 +3,11 @@ import unittest
|
|||
|
||||
import numpy as np
|
||||
import torch
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.encoder.utils.generic_utils import setup_encoder_model
|
||||
from TTS.encoder.utils.io import save_checkpoint
|
||||
from TTS.tts.utils.speakers import SpeakerManager
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
@ -30,7 +30,7 @@ class SpeakerManagerTest(unittest.TestCase):
|
|||
|
||||
# create a dummy speaker encoder
|
||||
model = setup_encoder_model(config)
|
||||
save_checkpoint(model, None, None, get_tests_input_path(), 0)
|
||||
save_checkpoint(config, model, None, None, 0, 0, get_tests_input_path())
|
||||
|
||||
# load audio processor and speaker encoder
|
||||
ap = AudioProcessor(**config.audio)
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
import os
|
||||
import unittest
|
||||
|
||||
from trainer.io import save_checkpoint
|
||||
|
||||
from tests import get_tests_input_path
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.models import setup_model
|
||||
from TTS.utils.io import save_checkpoint
|
||||
from TTS.utils.synthesizer import Synthesizer
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue