mirror of https://github.com/coqui-ai/TTS.git
Merge branch 'pr/gerazov/642' into dev
commit
1f0385a343
|
@ -1,8 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Train Glow TTS model."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -14,10 +12,12 @@ import torch
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import GlowTTSLoss
|
||||
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import parse_speakers
|
||||
|
@ -25,18 +25,15 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import NoamLR, setup_torch_training_env
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||
|
||||
|
||||
def setup_loader(ap, r, is_val=False, verbose=False):
|
||||
if is_val and not c.run_eval:
|
||||
loader = None
|
||||
|
@ -468,7 +465,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
return keep_avg.avg_values
|
||||
|
||||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
||||
|
@ -567,81 +563,9 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv
|
||||
)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.mixed_precision:
|
||||
print(" > Mixed precision enabled.")
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='glow_tts')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -11,6 +11,7 @@ import numpy as np
|
|||
from random import randrange
|
||||
|
||||
import torch
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
# DISTRIBUTED
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.utils.data import DataLoader
|
||||
|
@ -18,7 +19,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import SpeedySpeechLoss
|
||||
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import parse_speakers
|
||||
|
@ -26,14 +27,10 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import NoamLR, setup_torch_training_env
|
||||
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||
|
@ -524,86 +521,15 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
target_loss = train_avg_loss_dict['avg_loss']
|
||||
if c.run_eval:
|
||||
target_loss = eval_avg_loss_dict['avg_loss']
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
|
||||
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
|
||||
global_step, epoch, c.r,
|
||||
OUT_PATH)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv
|
||||
)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.mixed_precision:
|
||||
print(" > Mixed precision enabled.")
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='tts')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Trains Tacotron based TTS models."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -12,10 +10,11 @@ from random import randrange
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.tts.datasets.preprocess import load_meta_data
|
||||
from TTS.tts.datasets.TTSDataset import MyDataset
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
|
||||
from TTS.tts.utils.generic_utils import setup_model
|
||||
from TTS.tts.utils.io import save_best_model, save_checkpoint
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import parse_speakers
|
||||
|
@ -23,15 +22,11 @@ from TTS.tts.utils.synthesis import synthesis
|
|||
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
|
||||
init_distributed, reduce_tensor)
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
|
||||
gradual_training_scheduler, set_weight_decay,
|
||||
setup_torch_training_env)
|
||||
|
@ -61,7 +56,11 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
|
|||
phoneme_language=c.phoneme_language,
|
||||
enable_eos_bos=c.enable_eos_bos_chars,
|
||||
verbose=verbose,
|
||||
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
|
||||
speaker_mapping=(speaker_mapping if (
|
||||
c.use_speaker_embedding
|
||||
and c.use_external_speaker_embedding_file
|
||||
) else None)
|
||||
)
|
||||
|
||||
if c.use_phonemes and c.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
|
@ -491,7 +490,6 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
|
|||
return keep_avg.avg_values
|
||||
|
||||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
|
||||
|
@ -636,84 +634,14 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
epoch,
|
||||
c.r,
|
||||
OUT_PATH,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None)
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv
|
||||
)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
print(f" > Training continues for {args.continue_path}")
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
check_config_tts(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if c.mixed_precision:
|
||||
print(" > Mixed precision mode is ON")
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='tacotron')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import glob
|
||||
#!/usr/bin/env python3
|
||||
"""Trains GAN based vocoder model."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -8,14 +9,13 @@ from inspect import signature
|
|||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.gan_dataset import GANDataset
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
|
@ -439,7 +439,6 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
return keep_avg.avg_values
|
||||
|
||||
|
||||
# FIXME: move args definition/parsing inside of main?
|
||||
def main(args): # pylint: disable=redefined-outer-name
|
||||
# pylint: disable=global-variable-undefined
|
||||
global train_data, eval_data
|
||||
|
@ -506,7 +505,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
|
||||
scheduler_disc.optimizer = optimizer_disc
|
||||
except RuntimeError:
|
||||
# retore only matching layers.
|
||||
# restore only matching layers.
|
||||
print(" > Partial model initialization...")
|
||||
model_dict = model_gen.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
|
@ -556,7 +555,8 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model_disc, criterion_disc, optimizer_disc,
|
||||
scheduler_gen, scheduler_disc, ap, global_step,
|
||||
epoch)
|
||||
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap,
|
||||
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
|
||||
criterion_disc, ap,
|
||||
global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict[c.target_loss]
|
||||
|
@ -575,78 +575,9 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument('--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(
|
||||
args.continue_path +
|
||||
"/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
|
||||
args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='gan')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import argparse
|
||||
import glob
|
||||
#!/usr/bin/env python3
|
||||
"""Trains WaveGrad vocoder models."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
@ -12,14 +13,11 @@ from torch.nn.parallel import DistributedDataParallel as DDP_th
|
|||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
create_experiment_folder, get_git_branch,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
|
@ -54,6 +52,7 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
if is_val else c.num_loader_workers,
|
||||
pin_memory=False)
|
||||
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
|
@ -195,18 +194,19 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict()
|
||||
if c.mixed_precision else None)
|
||||
save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
|
@ -247,6 +247,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
else:
|
||||
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||
|
||||
|
||||
# forward pass
|
||||
noise_hat = model(x_noisy, m, noise_scale)
|
||||
|
||||
|
@ -254,6 +255,7 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
loss = criterion(noise, noise_hat)
|
||||
loss_wavegrad_dict = {'wavegrad_loss': loss}
|
||||
|
||||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_wavegrad_dict.items():
|
||||
if isinstance(value, (int, float)):
|
||||
|
@ -415,87 +417,14 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None)
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help=
|
||||
'Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default='',
|
||||
required='--config_path' not in sys.argv)
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument('--config_path',
|
||||
type=str,
|
||||
help='Path to config file for training.',
|
||||
required='--continue_path' not in sys.argv)
|
||||
parser.add_argument('--debug',
|
||||
type=bool,
|
||||
default=False,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
'--rank',
|
||||
type=int,
|
||||
default=0,
|
||||
help='DISTRIBUTED: process rank for distributed training.')
|
||||
parser.add_argument('--group_id',
|
||||
type=str,
|
||||
default="",
|
||||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(
|
||||
args.continue_path +
|
||||
"/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
# DISTRIBUTED
|
||||
if c.mixed_precision:
|
||||
print(" > Mixed precision is enabled")
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == '':
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name,
|
||||
args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name='VOCODER')
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text('model-description', c['run_description'], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='wavegrad')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
import argparse
|
||||
#!/usr/bin/env python3
|
||||
"""Train WaveRNN vocoder model."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import time
|
||||
import glob
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
@ -11,18 +12,14 @@ from torch.utils.data import DataLoader
|
|||
|
||||
# from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.tts.utils.visual import plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
from TTS.utils.generic_utils import (
|
||||
KeepAverage,
|
||||
count_parameters,
|
||||
create_experiment_folder,
|
||||
get_git_branch,
|
||||
remove_experiment_folder,
|
||||
set_init_dict,
|
||||
)
|
||||
|
@ -181,18 +178,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
save_checkpoint(
|
||||
model,
|
||||
optimizer,
|
||||
scheduler,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
)
|
||||
|
||||
# synthesize a full voice
|
||||
rand_idx = random.randrange(0, len(train_data))
|
||||
|
@ -448,87 +446,9 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--continue_path",
|
||||
type=str,
|
||||
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
|
||||
default="",
|
||||
required="--config_path" not in sys.argv,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restore_path",
|
||||
type=str,
|
||||
help="Model file to be restored. Use to finetune a model.",
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
help="Path to config file for training.",
|
||||
required="--continue_path" not in sys.argv,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Do not verify commit integrity to run training.",
|
||||
)
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="DISTRIBUTED: process rank for distributed training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--group_id", type=str, default="", help="DISTRIBUTED: process group id."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.continue_path != "":
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, "config.json")
|
||||
list_of_files = glob.glob(
|
||||
args.continue_path + "/*.pth.tar"
|
||||
) # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
# check_config(c)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
OUT_PATH = args.continue_path
|
||||
if args.continue_path == "":
|
||||
OUT_PATH = create_experiment_folder(
|
||||
c.output_path, c.run_name, args.debug
|
||||
)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, "test_audios")
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(AUDIO_PATH, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(
|
||||
c, args.config_path, OUT_PATH, new_fields
|
||||
)
|
||||
os.chmod(AUDIO_PATH, 0o775)
|
||||
os.chmod(OUT_PATH, 0o775)
|
||||
|
||||
LOG_DIR = OUT_PATH
|
||||
tb_logger = TensorboardLogger(LOG_DIR, model_name="VOCODER")
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-description", c["run_description"], 0)
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_type='wavernn')
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Argument parser for training scripts."""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import glob
|
||||
import os
|
||||
|
||||
from TTS.utils.generic_utils import (
|
||||
create_experiment_folder, get_git_branch)
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.io import copy_model_files, load_config
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
||||
from TTS.tts.utils.generic_utils import check_config_tts
|
||||
|
||||
|
||||
def parse_arguments(argv):
|
||||
"""Parse command line arguments of training scripts.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
argv : list
|
||||
This is a list of input arguments as given by sys.argv
|
||||
|
||||
Returns
|
||||
-------
|
||||
argparse.Namespace
|
||||
Parsed arguments.
|
||||
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--continue_path",
|
||||
type=str,
|
||||
help=("Training output folder to continue training. Used to continue "
|
||||
"a training. If it is used, 'config_path' is ignored."),
|
||||
default="",
|
||||
required="--config_path" not in argv)
|
||||
parser.add_argument(
|
||||
"--restore_path",
|
||||
type=str,
|
||||
help="Model file to be restored. Use to finetune a model.",
|
||||
default="")
|
||||
parser.add_argument(
|
||||
"--config_path",
|
||||
type=str,
|
||||
help="Path to config file for training.",
|
||||
required="--continue_path" not in argv)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Do not verify commit integrity to run training.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=0,
|
||||
help="DISTRIBUTED: process rank for distributed training.")
|
||||
parser.add_argument(
|
||||
"--group_id",
|
||||
type=str,
|
||||
default="",
|
||||
help="DISTRIBUTED: process group id.")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def get_last_checkpoint(path):
|
||||
"""Get latest checkpoint from a list of filenames.
|
||||
|
||||
It is based on globbing for `*.pth.tar` and the RegEx
|
||||
`checkpoint_([0-9]+)`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : list
|
||||
Path to files to be compared.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If no checkpoint files are found.
|
||||
|
||||
Returns
|
||||
-------
|
||||
last_checkpoint : str
|
||||
Last checkpoint filename.
|
||||
|
||||
"""
|
||||
last_checkpoint_num = 0
|
||||
last_checkpoint = None
|
||||
filenames = glob.glob(
|
||||
os.path.join(path, "/*.pth.tar"))
|
||||
for filename in filenames:
|
||||
try:
|
||||
checkpoint_num = int(
|
||||
re.search(r"checkpoint_([0-9]+)", filename).groups()[0])
|
||||
if checkpoint_num > last_checkpoint_num:
|
||||
last_checkpoint_num = checkpoint_num
|
||||
last_checkpoint = filename
|
||||
except AttributeError: # if there's no match in the filename
|
||||
pass
|
||||
if last_checkpoint is None:
|
||||
raise ValueError(f"No checkpoints in {path}!")
|
||||
return last_checkpoint
|
||||
|
||||
|
||||
def process_args(args, model_type):
|
||||
"""Process parsed comand line arguments.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args : argparse.Namespace or dict like
|
||||
Parsed input arguments.
|
||||
model_type : str
|
||||
Model type used to check config parameters and setup the TensorBoard
|
||||
logger. One of:
|
||||
- tacotron
|
||||
- glow_tts
|
||||
- speedy_speech
|
||||
- gan
|
||||
- wavegrad
|
||||
- wavernn
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If `model_type` is not one of implemented choices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
c : TTS.utils.io.AttrDict
|
||||
Config paramaters.
|
||||
out_path : str
|
||||
Path to save models and logging.
|
||||
audio_path : str
|
||||
Path to save generated test audios.
|
||||
c_logger : TTS.utils.console_logger.ConsoleLogger
|
||||
Class that does logging to the console.
|
||||
tb_logger : TTS.utils.tensorboard.TensorboardLogger
|
||||
Class that does the TensorBoard loggind.
|
||||
|
||||
"""
|
||||
if args.continue_path != "":
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, "config.json")
|
||||
list_of_files = glob.glob(
|
||||
os.path.join(args.continue_path, "*.pth.tar")
|
||||
) # * means all if need specific format then *.csv
|
||||
args.restore_path = max(list_of_files, key=os.path.getctime)
|
||||
# checkpoint number based continuing
|
||||
# args.restore_path = get_last_checkpoint(args.continue_path)
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
|
||||
if model_type in "tacotron glow_tts speedy_speech":
|
||||
model_class = "TTS"
|
||||
elif model_type in "gan wavegrad wavernn":
|
||||
model_class = "VOCODER"
|
||||
else:
|
||||
raise ValueError("model type {model_type} not recognized!")
|
||||
|
||||
if model_class == "TTS":
|
||||
check_config_tts(c)
|
||||
elif model_class == "VOCODER":
|
||||
print("Vocoder config checker not implemented, "
|
||||
"skipping ...")
|
||||
else:
|
||||
raise ValueError(f"model type {model_type} not recognized!")
|
||||
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
if model_type in "tacotron wavegrad wavernn" and c.mixed_precision:
|
||||
print(" > Mixed precision mode is ON")
|
||||
|
||||
out_path = args.continue_path
|
||||
if args.continue_path == "":
|
||||
out_path = create_experiment_folder(c.output_path, c.run_name,
|
||||
args.debug)
|
||||
|
||||
audio_path = os.path.join(out_path, "test_audios")
|
||||
|
||||
c_logger = ConsoleLogger()
|
||||
|
||||
if args.rank == 0:
|
||||
os.makedirs(audio_path, exist_ok=True)
|
||||
new_fields = {}
|
||||
if args.restore_path:
|
||||
new_fields["restore_path"] = args.restore_path
|
||||
new_fields["github_branch"] = get_git_branch()
|
||||
copy_model_files(c, args.config_path,
|
||||
out_path, new_fields)
|
||||
os.chmod(audio_path, 0o775)
|
||||
os.chmod(out_path, 0o775)
|
||||
|
||||
log_path = out_path
|
||||
|
||||
tb_logger = TensorboardLogger(log_path, model_name=model_class)
|
||||
|
||||
# write model desc to tensorboard
|
||||
tb_logger.tb_add_text("model-description", c["run_description"], 0)
|
||||
|
||||
return c, out_path, audio_path, c_logger, tb_logger
|
|
@ -21,7 +21,7 @@ def test_phoneme_to_sequence():
|
|||
text_hat_with_params = sequence_to_phoneme(sequence, tp=conf.characters)
|
||||
gt = "ɹiːsənt ɹɪsɜːtʃ æt hɑːɹvɚd hɐz ʃoʊn mɛdᵻteɪɾɪŋ fɔːɹ æz lɪɾəl æz eɪt wiːks kæn æktʃuːəli ɪnkɹiːs, ðə ɡɹeɪ mæɾɚɹ ɪnðə pɑːɹts ʌvðə bɹeɪn ɹɪspɑːnsəbəl fɔːɹ ɪmoʊʃənəl ɹɛɡjuːleɪʃən ænd lɜːnɪŋ!"
|
||||
assert text_hat == text_hat_with_params == gt
|
||||
|
||||
|
||||
# multiple punctuations
|
||||
text = "Be a voice, not an! echo?"
|
||||
sequence = phoneme_to_sequence(text, text_cleaner, lang)
|
||||
|
|
Loading…
Reference in New Issue