mirror of https://github.com/coqui-ai/TTS.git

make style

parent d376647ca0
commit ca787be193
@@ -8,10 +8,10 @@ import numpy as np
 import tensorflow as tf
 import torch

+from TTS.tts.models import setup_model
 from TTS.tts.tf.models.tacotron2 import Tacotron2
 from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
 from TTS.tts.tf.utils.generic_utils import save_checkpoint
-from TTS.tts.models import setup_model
 from TTS.tts.utils.text.symbols import phonemes, symbols
 from TTS.utils.io import load_config
@@ -1,9 +1,10 @@
 import os
 import sys
 import traceback

+from TTS.trainer import TrainerTTS
 from TTS.utils.arguments import init_training
 from TTS.utils.generic_utils import remove_experiment_folder
-from TTS.trainer import TrainerTTS


 def main():
TTS/trainer.py (218 changed lines)
@@ -4,21 +4,19 @@ import importlib
 import logging
 import os
 import time
+from argparse import Namespace
+from dataclasses import dataclass, field
+from typing import Dict, List, Tuple, Union
+
 import torch

 from coqpit import Coqpit
-from dataclasses import dataclass, field
-from typing import Tuple, Dict, List, Union
-
-from argparse import Namespace
 # DISTRIBUTED
 from torch import nn
 from torch.nn.parallel import DistributedDataParallel as DDP_th
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.tts.datasets import TTSDataset, load_meta_data
 from TTS.tts.layers import setup_loss
 from TTS.tts.models import setup_model
@@ -30,49 +28,48 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
 from TTS.utils.distribute import init_distributed
 from TTS.utils.generic_utils import KeepAverage, count_parameters, set_init_dict
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.training import check_update, setup_torch_training_env


 @dataclass
 class TrainingArgs(Coqpit):
     continue_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder.'
-        })
+            "help": "Path to a training folder to continue training. Restore the model from the last checkpoint and continue training under the same folder."
+        },
+    )
     restore_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            'Path to a model checkpoint. Restore the model with the given checkpoint and start a new training.'
-        })
+            "help": "Path to a model checkpoint. Restore the model with the given checkpoint and start a new training."
+        },
+    )
     best_path: str = field(
-        default='',
+        default="",
         metadata={
-            'help':
-            "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
-        })
-    config_path: str = field(
-        default='', metadata={'help': 'Path to the configuration file.'})
-    rank: int = field(
-        default=0, metadata={'help': 'Process rank in distributed training.'})
-    group_id: str = field(
-        default='',
-        metadata={'help': 'Process group id in distributed training.'})
+            "help": "Best model file to be used for extracting best loss. If not specified, the latest best model in continue path is used"
+        },
+    )
+    config_path: str = field(default="", metadata={"help": "Path to the configuration file."})
+    rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})
+    group_id: str = field(default="", metadata={"help": "Process group id in distributed training."})


 # pylint: disable=import-outside-toplevel, too-many-public-methods
 class TrainerTTS:
     use_cuda, num_gpus = setup_torch_training_env(True, False)

-    def __init__(self,
-                 args: Union[Coqpit, Namespace],
-                 config: Coqpit,
-                 c_logger: ConsoleLogger,
-                 tb_logger: TensorboardLogger,
-                 model: nn.Module = None,
-                 output_path: str = None) -> None:
+    def __init__(
+        self,
+        args: Union[Coqpit, Namespace],
+        config: Coqpit,
+        c_logger: ConsoleLogger,
+        tb_logger: TensorboardLogger,
+        model: nn.Module = None,
+        output_path: str = None,
+    ) -> None:
         self.args = args
         self.config = config
         self.c_logger = c_logger
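
Note: TrainingArgs is a Coqpit dataclass, so the "help" metadata above doubles as CLI help text. As a rough sketch of the same field-metadata pattern using only the standard library (the argparse wiring below is illustrative, not the Coqpit API):

    import argparse
    from dataclasses import dataclass, field, fields

    @dataclass
    class TrainingArgsSketch:
        continue_path: str = field(default="", metadata={"help": "Training folder to resume from."})
        rank: int = field(default=0, metadata={"help": "Process rank in distributed training."})

    def parse_cli(cls=TrainingArgsSketch):
        # Build an argparse parser out of the dataclass fields and their help metadata.
        parser = argparse.ArgumentParser()
        for f in fields(cls):
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default, help=f.metadata.get("help", ""))
        return cls(**vars(parser.parse_args()))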
@@ -90,8 +87,7 @@ class TrainerTTS:
         self.keep_avg_train = None
         self.keep_avg_eval = None

-        log_file = os.path.join(self.output_path,
-                                f"trainer_{args.rank}_log.txt")
+        log_file = os.path.join(self.output_path, f"trainer_{args.rank}_log.txt")
         self._setup_logger_config(log_file)

         # model, audio processor, datasets, loss
@@ -106,16 +102,19 @@ class TrainerTTS:

         # default speaker manager
         self.speaker_manager = self.get_speaker_manager(
-            self.config, args.restore_path, self.config.output_path, self.data_train)
+            self.config, args.restore_path, self.config.output_path, self.data_train
+        )

         # init TTS model
         if model is not None:
             self.model = model
         else:
             self.model = self.get_model(
-                len(self.model_characters), self.speaker_manager.num_speakers,
-                self.config, self.speaker_manager.x_vector_dim
-                if self.speaker_manager.x_vectors else None)
+                len(self.model_characters),
+                self.speaker_manager.num_speakers,
+                self.config,
+                self.speaker_manager.x_vector_dim if self.speaker_manager.x_vectors else None,
+            )

         # setup criterion
         self.criterion = self.get_criterion(self.config)
@@ -126,13 +125,16 @@ class TrainerTTS:

         # DISTRIBUTED
         if self.num_gpus > 1:
-            init_distributed(args.rank, self.num_gpus, args.group_id,
-                             self.config.distributed["backend"],
-                             self.config.distributed["url"])
+            init_distributed(
+                args.rank,
+                self.num_gpus,
+                args.group_id,
+                self.config.distributed["backend"],
+                self.config.distributed["url"],
+            )

         # scalers for mixed precision training
-        self.scaler = torch.cuda.amp.GradScaler(
-        ) if self.config.mixed_precision and self.use_cuda else None
+        self.scaler = torch.cuda.amp.GradScaler() if self.config.mixed_precision and self.use_cuda else None

         # setup optimizer
         self.optimizer = self.get_optimizer(self.model, self.config)
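
Note: the body of init_distributed lives in TTS.utils.distribute and is not part of this diff. A minimal sketch of what such a helper typically does with torch.distributed (assumed, not the verbatim implementation):

    import torch
    import torch.distributed as dist

    def init_distributed(rank, num_gpus, group_id, backend, url):
        # Every worker joins the same process group under its own rank.
        dist.init_process_group(backend, init_method=url, world_size=num_gpus, rank=rank, group_name=group_id)
        # Pin this process to one GPU so DDP gradients reduce on the right device.
        torch.cuda.set_device(rank % num_gpus)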
@@ -154,8 +156,7 @@ class TrainerTTS:
         print("\n > Model has {} parameters".format(num_params))

     @staticmethod
-    def get_model(num_chars: int, num_speakers: int, config: Coqpit,
-                  x_vector_dim: int) -> nn.Module:
+    def get_model(num_chars: int, num_speakers: int, config: Coqpit, x_vector_dim: int) -> nn.Module:
         model = setup_model(num_chars, num_speakers, config, x_vector_dim)
         return model
@@ -182,26 +183,32 @@ class TrainerTTS:
         return model_characters

     @staticmethod
-    def get_speaker_manager(config: Coqpit,
-                            restore_path: str = "",
-                            out_path: str = "",
-                            data_train: List = []) -> SpeakerManager:
+    def get_speaker_manager(
+        config: Coqpit, restore_path: str = "", out_path: str = "", data_train: List = []
+    ) -> SpeakerManager:
         speaker_manager = SpeakerManager()
         if restore_path:
-            speakers_file = os.path.join(os.path.dirname(restore_path),
-                                         "speaker.json")
+            speakers_file = os.path.join(os.path.dirname(restore_path), "speaker.json")
             if not os.path.exists(speakers_file):
                 print(
                     "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                 )
                 speakers_file = config.external_speaker_embedding_file

             if config.use_external_speaker_embedding_file:
                 speaker_manager.load_x_vectors_file(speakers_file)
             else:
                 speaker_manager.load_ids_file(speakers_file)
         elif config.use_external_speaker_embedding_file and config.external_speaker_embedding_file:
             speaker_manager.load_x_vectors_file(config.external_speaker_embedding_file)
         else:
             speaker_manager.parse_speakers_from_items(data_train)
             file_path = os.path.join(out_path, "speakers.json")
             speaker_manager.save_ids_file(file_path)
         return speaker_manager

     @staticmethod
-    def get_scheduler(config: Coqpit,
-                      optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+    def get_scheduler(config: Coqpit, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
         lr_scheduler = config.lr_scheduler
         lr_scheduler_params = config.lr_scheduler_params
         if lr_scheduler is None:
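
Note: the hunk cuts off inside get_scheduler. The usual continuation for this kind of name-based lookup (an assumption; the rest of the function is not shown in this diff) is:

    # Resolve the scheduler class named in the config and build it with its params.
    if lr_scheduler is None:
        return None
    scheduler_cls = getattr(torch.optim.lr_scheduler, lr_scheduler)
    return scheduler_cls(optimizer, **lr_scheduler_params)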
@@ -224,7 +231,7 @@ class TrainerTTS:
         restore_path: str,
         model: nn.Module,
         optimizer: torch.optim.Optimizer,
-        scaler: torch.cuda.amp.GradScaler = None
+        scaler: torch.cuda.amp.GradScaler = None,
     ) -> Tuple[nn.Module, torch.optim.Optimizer, torch.cuda.amp.GradScaler, int]:
         print(" > Restoring from %s ..." % os.path.basename(restore_path))
         checkpoint = torch.load(restore_path)
@@ -245,13 +252,21 @@ class TrainerTTS:

         for group in optimizer.param_groups:
             group["lr"] = self.config.lr
-        print(" > Model restored from step %d" % checkpoint["step"], )
+        print(
+            " > Model restored from step %d" % checkpoint["step"],
+        )
         restore_step = checkpoint["step"]
         return model, optimizer, scaler, restore_step

-    def _get_loader(self, r: int, ap: AudioProcessor, is_eval: bool,
-                    data_items: List, verbose: bool,
-                    speaker_mapping: Union[Dict, List]) -> DataLoader:
+    def _get_loader(
+        self,
+        r: int,
+        ap: AudioProcessor,
+        is_eval: bool,
+        data_items: List,
+        verbose: bool,
+        speaker_mapping: Union[Dict, List],
+    ) -> DataLoader:
         if is_eval and not self.config.run_eval:
             loader = None
         else:
@@ -295,17 +310,15 @@ class TrainerTTS:
         )
         return loader

-    def get_train_dataloader(self, r: int, ap: AudioProcessor,
-                             data_items: List, verbose: bool,
-                             speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, False, data_items, verbose,
-                                speaker_mapping)
+    def get_train_dataloader(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, False, data_items, verbose, speaker_mapping)

-    def get_eval_dataloder(self, r: int, ap: AudioProcessor, data_items: List,
-                           verbose: bool,
-                           speaker_mapping: Union[List, Dict]) -> DataLoader:
-        return self._get_loader(r, ap, True, data_items, verbose,
-                                speaker_mapping)
+    def get_eval_dataloder(
+        self, r: int, ap: AudioProcessor, data_items: List, verbose: bool, speaker_mapping: Union[List, Dict]
+    ) -> DataLoader:
+        return self._get_loader(r, ap, True, data_items, verbose, speaker_mapping)

     def format_batch(self, batch: List) -> Dict:
         # setup input batch
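
Note: most of _get_loader's body is elided from this diff. Given the DataLoader and DistributedSampler imports at the top of TTS/trainer.py, the multi-GPU branch presumably looks like this sketch (names and defaults assumed, not the verbatim TTS code):

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def make_loader(dataset, batch_size, num_gpus, num_workers):
        # Shard the dataset across processes only when training on multiple GPUs.
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=sampler is None,  # a sampler and shuffle=True are mutually exclusive
            sampler=sampler,
            num_workers=num_workers,
            pin_memory=False,
        )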
@@ -390,8 +403,7 @@ class TrainerTTS:
             "item_idx": item_idx,
         }

-    def train_step(self, batch: Dict, batch_n_steps: int, step: int,
-                   loader_start_time: float) -> Tuple[Dict, Dict]:
+    def train_step(self, batch: Dict, batch_n_steps: int, step: int, loader_start_time: float) -> Tuple[Dict, Dict]:
         self.on_train_step_start()
         step_start_time = time.time()
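
Note: train_step's body is mostly elided here, but the self.scaler created earlier implies the standard torch.cuda.amp recipe. A self-contained sketch of that recipe (the batch keys and the clipping threshold are assumptions, not TTS code):

    import torch

    def amp_step(model, batch, criterion, optimizer, scaler):
        # One optimization step, with mixed precision when a GradScaler is given.
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            outputs = model(batch["inputs"])
            loss = criterion(outputs, batch["targets"])
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # expose true gradient values to clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        return loss.item()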
@@ -560,7 +572,9 @@ class TrainerTTS:
         self.tb_logger.tb_eval_figures(self.total_steps_done, figures)
         self.tb_logger.tb_eval_audios(self.total_steps_done, {"EvalAudio": eval_audios}, self.ap.sample_rate)

-    def test_run(self, ) -> None:
+    def test_run(
+        self,
+    ) -> None:
         print(" | > Synthesizing test sentences.")
         test_audios = {}
         test_figures = {}
@@ -581,28 +595,26 @@ class TrainerTTS:
                 do_trim_silence=False,
             ).values()

-            file_path = os.path.join(self.output_audio_path,
-                                     str(self.total_steps_done))
+            file_path = os.path.join(self.output_audio_path, str(self.total_steps_done))
             os.makedirs(file_path, exist_ok=True)
-            file_path = os.path.join(file_path,
-                                     "TestSentence_{}.wav".format(idx))
+            file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
             self.ap.save_wav(wav, file_path)
             test_audios["{}-audio".format(idx)] = wav
             test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
             test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)

-        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios,
-                                      self.config.audio["sample_rate"])
+        self.tb_logger.tb_test_audios(self.total_steps_done, test_audios, self.config.audio["sample_rate"])
         self.tb_logger.tb_test_figures(self.total_steps_done, test_figures)

     def _get_cond_inputs(self) -> Dict:
         # setup speaker_id
         speaker_id = 0 if self.config.use_speaker_embedding else None
         # setup x_vector
-        x_vector = (self.speaker_manager.get_x_vectors_by_speaker(
-            self.speaker_manager.speaker_ids[0])
-                    if self.config.use_external_speaker_embedding_file
-                    and self.config.use_speaker_embedding else None)
+        x_vector = (
+            self.speaker_manager.get_x_vectors_by_speaker(self.speaker_manager.speaker_ids[0])
+            if self.config.use_external_speaker_embedding_file and self.config.use_speaker_embedding
+            else None
+        )
         # setup style_mel
         if self.config.has("gst_style_input"):
             style_wav = self.config.gst_style_input
@@ -611,40 +623,29 @@ class TrainerTTS:
         if style_wav is None and "use_gst" in self.config and self.config.use_gst:
             # initialize GST with a zero dict.
             style_wav = {}
-            print(
-                "WARNING: You did not provide a GST style wav; a zero tensor will be used!"
-            )
+            print("WARNING: You did not provide a GST style wav; a zero tensor will be used!")
             for i in range(self.config.gst["gst_num_style_tokens"]):
                 style_wav[str(i)] = 0
-        cond_inputs = {
-            "speaker_id": speaker_id,
-            "style_wav": style_wav,
-            "x_vector": x_vector
-        }
+        cond_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "x_vector": x_vector}
         return cond_inputs

     def fit(self) -> None:
         if self.restore_step != 0 or self.args.best_path:
-            print(" > Restoring best loss from "
-                  f"{os.path.basename(self.args.best_path)} ...")
-            self.best_loss = torch.load(self.args.best_path,
-                                        map_location="cpu")["model_loss"]
+            print(" > Restoring best loss from " f"{os.path.basename(self.args.best_path)} ...")
+            self.best_loss = torch.load(self.args.best_path, map_location="cpu")["model_loss"]
             print(f" > Starting with loaded last best loss {self.best_loss}.")

         # define data loaders
         self.train_loader = self.get_train_dataloader(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-        self.eval_loader = (self.get_eval_dataloder(
-            self.config.r,
-            self.ap,
-            self.data_train,
-            verbose=True,
-            speaker_mapping=self.speaker_manager.speaker_ids)
-            if self.config.run_eval else None)
+            self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+        )
+        self.eval_loader = (
+            self.get_eval_dataloder(
+                self.config.r, self.ap, self.data_train, verbose=True, speaker_mapping=self.speaker_manager.speaker_ids
+            )
+            if self.config.run_eval
+            else None
+        )

         self.total_steps_done = self.restore_step
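
Note: fit() reads best_path with torch.load(...)["model_loss"], and restore_model above reads checkpoint["step"], so checkpoints are plain dicts saved with torch.save. A sketch of that convention (the "model" and "optimizer" keys are assumed; "step" and "model_loss" are taken from this diff):

    import torch

    def save_checkpoint(path, model, optimizer, step, loss):
        # Bundle everything needed to resume training into one dict.
        torch.save(
            {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step, "model_loss": loss},
            path,
        )

    def load_best_loss(path):
        # Only the stored loss is needed to warm-start best-model tracking.
        return torch.load(path, map_location="cpu")["model_loss"]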
@@ -667,8 +668,7 @@ class TrainerTTS:

     def save_best_model(self) -> None:
         self.best_loss = save_best_model(
-            self.keep_avg_eval["avg_loss"]
-            if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
+            self.keep_avg_eval["avg_loss"] if self.keep_avg_eval else self.keep_avg_train["avg_loss"],
             self.best_loss,
             self.model,
             self.optimizer,
@@ -685,10 +685,8 @@ class TrainerTTS:
     @staticmethod
     def _setup_logger_config(log_file: str) -> None:
         logging.basicConfig(
-            level=logging.INFO,
-            format="",
-            handlers=[logging.FileHandler(log_file),
-                      logging.StreamHandler()])
+            level=logging.INFO, format="", handlers=[logging.FileHandler(log_file), logging.StreamHandler()]
+        )

     def on_epoch_start(self) -> None:  # pylint: disable=no-self-use
         if hasattr(self.model, "on_epoch_start"):
@@ -1,9 +1,11 @@
 import sys
-import numpy as np
 from collections import Counter
 from pathlib import Path
-from TTS.tts.datasets.TTSDataset import TTSDataset
+
+import numpy as np
+
+from TTS.tts.datasets.formatters import *
+from TTS.tts.datasets.TTSDataset import TTSDataset

 ####################
 # UTILITIES
@@ -7,7 +7,6 @@ from typing import List

 from tqdm import tqdm

-
 ########################
 # DATASETS
 ########################
@@ -4,13 +4,13 @@ import torch.nn as nn
 from TTS.tts.layers.align_tts.mdn import MDNBlock
 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 class AlignTTS(nn.Module):
@@ -6,11 +6,11 @@ from torch.nn import functional as F

 from TTS.tts.layers.glow_tts.decoder import Decoder
 from TTS.tts.layers.glow_tts.encoder import Encoder
+from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
+from TTS.tts.utils.data import sequence_mask
 from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor
-from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
-from TTS.tts.utils.data import sequence_mask


 class GlowTTS(nn.Module):
@@ -3,13 +3,13 @@ from torch import nn

 from TTS.tts.layers.feed_forward.decoder import Decoder
 from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
-from TTS.utils.audio import AudioProcessor
 from TTS.tts.layers.feed_forward.encoder import Encoder
 from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
+from TTS.utils.audio import AudioProcessor


 class SpeedySpeech(nn.Module):
@@ -2,11 +2,11 @@
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


 class Tacotron(TacotronAbstract):
@@ -3,11 +3,11 @@ import numpy as np
 import torch
 from torch import nn

-from TTS.tts.utils.measures import alignment_diagonal_score
-from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.tts.layers.tacotron.gst_layers import GST
 from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
 from TTS.tts.models.tacotron_abstract import TacotronAbstract
+from TTS.tts.utils.measures import alignment_diagonal_score
+from TTS.tts.utils.visual import plot_alignment, plot_spectrogram


 class Tacotron2(TacotronAbstract):
@@ -1,5 +1,5 @@
-import torch
 import numpy as np
+import torch


 def _pad_data(x, length):
@@ -1,7 +1,7 @@
 import json
 import os
 import random
-from typing import Union, List, Any
+from typing import Any, List, Union

 import numpy as np
 import torch
@@ -11,9 +11,9 @@ import torch

 from TTS.config import load_config
 from TTS.tts.utils.text.symbols import parse_symbols
-from TTS.utils.logging import ConsoleLogger, TensorboardLogger
 from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
 from TTS.utils.io import copy_model_files
+from TTS.utils.logging import ConsoleLogger, TensorboardLogger


 def init_arguments(argv):