make style

pull/602/head
Eren Gölge 2021-05-28 13:37:08 +02:00
parent d376647ca0
commit ca787be193
13 changed files with 130 additions and 130 deletions
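Note: this commit is a pure style pass; the hunks below reflow code and reorder imports without behavior changes. The quote normalization and line wrapping are characteristic of black (with the repository's 120-character line length) and the import ordering of isort. Assuming those tools sit behind the make style target, the rewrite can be reproduced on any snippet from this diff:

import black
import isort

# One of the old-style lines from this diff, before formatting.
src = (
    "config_path: str = field(\n"
    "    default='', metadata={'help': 'Path to the configuration file.'})\n"
)

# isort.code() sorts import blocks; black.format_str() normalizes quotes
# and wraps lines to the configured length (120 assumed here).
print(black.format_str(isort.code(src), mode=black.Mode(line_length=120)))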

View File

@@ -8,10 +8,10 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.tts.models import setup_model
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.models import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config

View File

@@ -1,9 +1,10 @@
import os
import sys
import traceback
from TTS.trainer import TrainerTTS
from TTS.utils.arguments import init_training
from TTS.utils.generic_utils import remove_experiment_folder
from TTS.trainer import TrainerTTS
def main():

View File

@@ -4,21 +4,19 @@ import importlib
import logging
import os
import time
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Union
import torch
from coqpit import Coqpit
from dataclasses import dataclass, field
from typing import Tuple, Dict, List, Union
from argparse import Namespace
# DISTRIBUTED
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.tts.datasets import TTSDataset, load_meta_data
from TTS.tts.layers import setup_loss
from TTS.tts.models import setup_model
@@ -30,49 +28,48 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.training import check_update, setup_torch_training_env
@dataclass
class TrainingArgs(Coqpit):
continue_path: str = field(
default='',
default="",
metadata={
'help':
'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
})
"help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
},
)
restore_path: str = field(
default='',
default="",
metadata={
'help':
'Path to a model checkpoint. Restore the model with the given checkpoint and start a new training.'
})
"help": "Path to a model checkpoit. Restore the model with the given checkpoint and start a new training."
},
)
best_path: str = field(
default='',
default="",
metadata={
'help':
"Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
})
config_path: str = field(
default='', metadata={'help': 'Path to the configuration file.'})
rank: int = field(
default=0, metadata={'help': 'Process rank in distributed training.'})
group_id: str = field(
default='',
metadata={'help': 'Process group id in distributed training.'})
"help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
},
)
config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})
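The metadata={"help": ...} entries above are how these Coqpit dataclasses carry CLI help text. Coqpit's own CLI plumbing is not part of this diff; as a standard-library illustration of the same pattern, field metadata can be turned into an argument parser like this (the class name here is a hypothetical stand-in):

import argparse
from dataclasses import dataclass, field, fields

@dataclass
class TrainingArgsSketch:  # hypothetical stand-in for the Coqpit class above
    continue_path: str = field(default="", metadata={"help": "Path to a training folder to continue training."})
    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})

parser = argparse.ArgumentParser()
for f in fields(TrainingArgsSketch):
    parser.add_argument(f"--{f.name}", type=f.type, default=f.default, help=f.metadata.get("help"))

args = parser.parse_args(["--rank", "1"])
print(args.rank)  # -> 1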
# pylint: disable=import-outside-toplevel, too-many-public-methods
class TrainerTTS:
use_cuda, num_gpus = setup_torch_training_env(True, False)
def __init__(self,
args: Union[Coqpit, Namespace],
config: Coqpit,
c_logger: ConsoleLogger,
tb_logger: TensorboardLogger,
model: nn.Module = None,
output_path: str = None) -> None:
def __init__(
self,
args: Union[Coqpit, Namespace],
config: Coqpit,
c_logger: ConsoleLogger,
tb_logger: TensorboardLogger,
model: nn.Module = None,
output_path: str = None,
) -> None:
self.args = args
self.config = config
self.c_logger = c_logger
@@ -90,8 +87,7 @@ class TrainerTTS:
self.keep_avg_train = None
self.keep_avg_eval = None
log_file = os.path.join(self.output_path,
f"trainer_{args.rank}_log.txt")
log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
self._setup_logger_config(log_file)
# model, audio processor, datasets, loss
@@ -106,16 +102,19 @@
# default speaker manager
self.speaker_manager = self.get_speaker_manager(
self.config, args.restore_path, self.config.output_path, self.data_train)
self.config, args.restore_path, self.config.output_path, self.data_train
)
# init TTS model
if model is not None:
self.model = model
else:
self.model = self.get_model(
len(self.model_characters), self.speaker_manager.num_speakers,
self.config, self.speaker_manager.x_vector_dim
if self.speaker_manager.x_vectors else None)
len(self.model_characters),
self.speaker_manager.num_speakers,
self.config,
self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
)
# setup criterion
self.criterion = self.get_criterion(self.config)
@@ -126,13 +125,16 @@
# DISTRIBUTED
if self.num_gpus > 1:
init_distributed(args.rank, self.num_gpus, args.group_id,
self.config.distributed["backend"],
self.config.distributed["url"])
init_distributed(
args.rank,
self.num_gpus,
args.group_id,
self.config.distributed["backend"],
self.config.distributed["url"],
)
# scalers for mixed precision training
self.scaler = torch.cuda.amp.GradScaler(
) if self.config.mixed_precision and self.use_cuda else None
self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None
# setup optimizer
self.optimizer = self.get_optimizer(self.model, self.config)
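For context on the scaler created above: torch.cuda.amp.GradScaler pairs with autocast inside the training step. A generic sketch of that pattern, with illustrative names rather than the trainer's actual step:

import torch

def amp_step(model, batch, criterion, optimizer, scaler):
    # Illustrative AMP step; scaler is None when mixed precision is off.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = criterion(model(batch["inputs"]), batch["targets"])
    if scaler is not None:
        scaler.scale(loss).backward()  # scale the loss to avoid fp16 underflow
        scaler.step(optimizer)         # unscale gradients, then optimizer.step()
        scaler.update()                # adapt the scale factor for the next step
    else:
        loss.backward()
        optimizer.step()
    return loss.item()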
@@ -154,8 +156,7 @@
print("\n > Model has {} parameters".format(num_params))
@staticmethod
def get_model(num_chars: int, num_speakers: int, config: Coqpit,
x_vector_dim: int) -> nn.Module:
def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
model = setup_model(num_chars, num_speakers, config, x_vector_dim)
return model
@@ -182,26 +183,32 @@ class TrainerTTS:
return model_characters
@staticmethod
def get_speaker_manager(config: Coqpit,
restore_path: str = "",
out_path: str = "",
data_train: List = []) -> SpeakerManager:
def get_speaker_manager(
config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
) -> SpeakerManager:
speaker_manager = SpeakerManager()
if restore_path:
speakers_file = os.path.join(os.path.dirname(restore_path),
"speaker.json")
speakers_file = os.path.join(os.path.dirname(restore_path), "speakers.json")
if not os.path.exists(speakers_file):
print(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
)
speakers_file = config.external_speaker_embedding_file
if config.use_external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(speakers_file)
else:
speaker_manager.load_ids_file(speakers_file)
elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
else:
speaker_manager.parse_speakers_from_items(data_train)
file_path = os.path.join(out_path, "speakers.json")
speaker_manager.save_ids_file(file_path)
return speaker_manager
@staticmethod
def get_scheduler(config: Coqpit,
optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
lr_scheduler = config.lr_scheduler
lr_scheduler_params = config.lr_scheduler_params
if lr_scheduler is None:
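The body of get_scheduler is truncated by this hunk; the usual pattern for resolving a scheduler from a config name is a dynamic lookup on torch.optim.lr_scheduler. A sketch under that assumption, not necessarily the exact implementation:

import torch

def get_scheduler_sketch(optimizer, name, params):
    # Assumed pattern: resolve the scheduler class by its config name.
    if name is None:
        return None
    scheduler_class = getattr(torch.optim.lr_scheduler, name)
    return scheduler_class(optimizer, **params)

# usage with a throwaway optimizer
opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
sched = get_scheduler_sketch(opt, "StepLR", {"step_size": 500, "gamma": 0.5})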
@@ -224,7 +231,7 @@ class TrainerTTS:
restore_path: str,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scaler: torch.cuda.amp.GradScaler = None
scaler: torch.cuda.amp.GradScaler = None,
) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
print(" > Restoring from %s ..." % os.path.basename(restore_path))
checkpoint = torch.load(restore_path)
@@ -245,13 +252,21 @@ class TrainerTTS:
for group in optimizer.param_groups:
group["lr"] = self.config.lr
print(" > Model restored from step %d" % checkpoint["step"], )
print(
" > Model restored from step %d" % checkpoint["step"],
)
restore_step = checkpoint["step"]
return model, optimizer, scaler, restore_step
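The restore path above reads checkpoint["step"] from a dict produced by torch.save. A self-contained sketch of that round trip; only the "step" key is visible in this diff, so the "model" and "optimizer" key names are assumptions:

import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": 1000}, "ckpt.pth")

restored = torch.load("ckpt.pth", map_location="cpu")
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
restore_step = restored["step"]  # mirrors checkpoint["step"] above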
def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
data_items: List, verbose: bool,
speaker_mapping: Union[Dict, List]) -> DataLoader:
def _get_loader(
self,
r: int,
ap: AudioProcessor,
is_eval: bool,
data_items: List,
verbose: bool,
speaker_mapping: Union[Dict, List],
) -> DataLoader:
if is_eval and not self.config.run_eval:
loader = None
else:
@@ -295,17 +310,15 @@
)
return loader
def get_train_dataloader(self, r: int, ap: AudioProcessor,
data_items: List, verbose: bool,
speaker_mapping: Union[List, Dict]) -> DataLoader:
return self._get_loader(r, ap, False, data_items, verbose,
speaker_mapping)
def get_train_dataloader(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
) -> DataLoader:
return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)
def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
verbose: bool,
speaker_mapping: Union[List, Dict]) -> DataLoader:
return self._get_loader(r, ap, True, data_items, verbose,
speaker_mapping)
def get_eval_dataloder(
self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
) -> DataLoader:
return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)
def format_batch(self, batch: List) -> Dict:
# setup input batch
@@ -390,8 +403,7 @@ class TrainerTTS:
"item_idx": item_idx,
}
def train_step(self, batch: Dict, batch_n_steps: int, step: int,
loader_start_time: float) -> Tuple[Dict, Dict]:
def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
self.on_train_step_start()
step_start_time = time.time()
@@ -560,7 +572,9 @@ class TrainerTTS:
self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)
def test_run(self, ) -> None:
def test_run(
self,
) -> None:
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@@ -581,28 +595,26 @@ class TrainerTTS:
do_trim_silence=False,
).values()
file_path = os.path.join(self.output_audio_path,
str(self.total_steps_done))
file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
self.ap.save_wav(wav, file_path)
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
self.config.audio["sample_rate"])
self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)
def _get_cond_inputs(self) -> Dict:
# setup speaker_id
speaker_id = 0 if self.config.use_speaker_embedding else None
# setup x_vector
x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
self.speaker_manager.speaker_ids[0])
if self.config.use_external_speaker_embedding_file
and self.config.use_speaker_embedding else None)
x_vector = (
self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
else None
)
# setup style_mel
if self.config.has("gst_style_input"):
style_wav = self.config.gst_style_input
@@ -611,40 +623,29 @@ class TrainerTTS:
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
# initialize GST with a zero dict.
style_wav = {}
print(
"WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!"
)
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
for i in range(self.config.gst["gst_num_style_tokens"]):
style_wav[str(i)] = 0
cond_inputs = {
"speaker_id": speaker_id,
"style_wav": style_wav,
"x_vector": x_vector
}
cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
return cond_inputs
def fit(self) -> None:
if self.restore_step != 0 or self.args.best_path:
print(" > Restoring best loss from "
f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path,
map_location="cpu")["model_loss"]
print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {self.best_loss}.")
# define data loaders
self.train_loader = self.get_train_dataloader(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_mapping=self.speaker_manager.speaker_ids)
self.eval_loader = (self.get_eval_dataloder(
self.config.r,
self.ap,
self.data_train,
verbose=True,
speaker_mapping=self.speaker_manager.speaker_ids)
if self.config.run_eval else None)
self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
)
self.eval_loader = (
self.get_eval_dataloder(
self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
)
if self.config.run_eval
else None
)
self.total_steps_done = self.restore_step
@@ -667,8 +668,7 @@ class TrainerTTS:
def save_best_model(self) -> None:
self.best_loss = save_best_model(
self.keep_avg_eval["avg_loss"]
if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
self.best_loss,
self.model,
self.optimizer,
@@ -685,10 +685,8 @@ class TrainerTTS:
@staticmethod
def _setup_logger_config(log_file: str) -> None:
logging.basicConfig(
level=logging.INFO,
format="",
handlers=[logging.FileHandler(log_file),
logging.StreamHandler()])
level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
)
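A note on _setup_logger_config above: an empty format string falls back to the bare %(message)s, and the two handlers send every record both to the log file and to the console. In isolation (the file name, shown for rank 0, is illustrative):

import logging

# Same handler setup as _setup_logger_config; the empty format string
# falls back to the bare message text.
logging.basicConfig(
    level=logging.INFO,
    format="",
    handlers=[logging.FileHandler("trainer_0_log.txt"), logging.StreamHandler()],
)
logging.info(" > goes to the console and to trainer_0_log.txt")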
def on_epoch_start(self) -> None: # pylint: disable=no-self-use
if hasattr(self.model, "on_epoch_start"):

View File

@@ -1,9 +1,11 @@
import sys
import numpy as np
from collections import Counter
from pathlib import Path
from TTS.tts.datasets.TTSDataset import TTSDataset
import numpy as np
from TTS.tts.datasets.formatters import *
from TTS.tts.datasets.TTSDataset import TTSDataset
####################
# UTILITIES

View File

@@ -7,7 +7,6 @@ from typing import List
from tqdm import tqdm
########################
# DATASETS
########################

View File

@@ -4,13 +4,13 @@ import torch.nn as nn
from TTS.tts.layers.align_tts.mdn import MDNBlock
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
class AlignTTS(nn.Module):

View File

@@ -6,11 +6,11 @@ from torch.nn import functional as F
from TTS.tts.layers.glow_tts.decoder import Decoder
from TTS.tts.layers.glow_tts.encoder import Encoder
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.utils.data import sequence_mask
class GlowTTS(nn.Module):

View File

@@ -3,13 +3,13 @@ from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
class SpeedySpeech(nn.Module):

View File

@@ -2,11 +2,11 @@
import torch
from torch import nn
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
from TTS.tts.models.tacotron_abstract import TacotronAbstract
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
class Tacotron(TacotronAbstract):

View File

@@ -3,11 +3,11 @@ import numpy as np
import torch
from torch import nn
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.tts.layers.tacotron.gst_layers import GST
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
from TTS.tts.models.tacotron_abstract import TacotronAbstract
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
class Tacotron2(TacotronAbstract):

View File

@@ -1,5 +1,5 @@
import torch
import numpy as np
import torch
def _pad_data(x, length):

View File

@@ -1,7 +1,7 @@
import json
import os
import random
from typing import Union, List, Any
from typing import Any, List, Union
import numpy as np
import torch

View File

@@ -11,9 +11,9 @@ import torch
from TTS.config import load_config
from TTS.tts.utils.text.symbols import parse_symbols
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
from TTS.utils.io import copy_model_files
from TTS.utils.logging import ConsoleLogger, TensorboardLogger
def init_arguments(argv):