mirror of https://github.com/coqui-ai/TTS.git

Refactor GlowTTS model and recipe for TTSTokenizer

parent 5a9653978a · commit bd461ace33
@@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC):

    def __init__(self, config: Coqpit):
        super().__init__()
        self._set_model_args(config)

    def _set_model_args(self, config: Coqpit):
        """Set model arguments from the config. Override this."""

    @staticmethod
    def init_from_config(config: Coqpit):
        """Init the model from given config.

        Override this depending on your model.
        """
        pass

    @abstractmethod
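The `init_from_config` hook shown above is meant to be overridden per model. A minimal sketch of what an override could look like, assuming a hypothetical `MyTTS` subclass in the same scope as `BaseModel` and `Coqpit`:

class MyTTS(BaseModel):
    @staticmethod
    def init_from_config(config: Coqpit):
        # Build whatever the model needs from the config, then return the instance.
        return MyTTS(config)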
@@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset

from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
-from TTS.tts.utils.text import make_symbols
+from TTS.tts.utils.text.symbols import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram

# pylint: skip-file
@@ -34,8 +34,20 @@ class BaseTTS(BaseModel):

    - 1D tensors `batch x 1`
    """

+    def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
+        super().__init__(config)
+        self.config = config
+        self.ap = ap
+        self.tokenizer = tokenizer
+        self.speaker_manager = speaker_manager
+        self._set_model_args(config)
+
    def _set_model_args(self, config: Coqpit):
-        """Setup model args based on the config type.
+        """Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
+
+        `ModelArgs` has all the fields required to initialize the model architecture.
+
+        `ModelConfig` has all the fields required for training, inference and contains `ModelArgs`.

        If the config is for training with a name like "*Config", then the model args are embedded in the
        `config.model_args`.
@@ -44,8 +56,8 @@ class BaseTTS(BaseModel):

        """
        # don't use isinstance(), to avoid recursive imports
        if "Config" in config.__class__.__name__:
+            num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
-            if "characters" in config:
-                _, self.config, num_chars = self.get_characters(config)
            self.config.num_chars = num_chars
            if hasattr(self.config, "model_args"):
                config.model_args.num_chars = num_chars
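For context, a standalone toy illustration of the name-based dispatch that `_set_model_args` performs; the classes here are stubs, not the real configs, and only the `ValueError` text is taken from the hunk below:

class GlowTTSConfig: ...
class GlowTTSArgs: ...

def set_model_args(config):
    if "Config" in config.__class__.__name__:
        return "training config: architecture fields live under config.model_args"
    if "Args" in config.__class__.__name__:
        return "bare model args: architecture fields live on the object directly"
    raise ValueError("config must be either a *Config or *Args")

print(set_model_args(GlowTTSConfig()))  # training-config path
print(set_model_args(GlowTTSArgs()))    # bare-args path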
@@ -58,18 +70,21 @@ class BaseTTS(BaseModel):

        else:
            raise ValueError("config must be either a *Config or *Args")

-    @staticmethod
-    def get_characters(config: Coqpit) -> str:
-        # TODO: implement CharacterProcessor
-        if config.characters is not None:
-            symbols, phonemes = make_symbols(**config.characters)
-        else:
-            from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
-
-            config.characters = CharactersConfig(**parse_symbols())
-        model_characters = phonemes if config.use_phonemes else symbols
-        num_chars = len(model_characters) + getattr(config, "add_blank", False)
-        return model_characters, config, num_chars
+    # @staticmethod
+    # def get_characters(config: Coqpit) -> str:
+    #     # TODO: implement CharacterProcessor
+    #     if config.characters is not None:
+    #         symbols, phonemes = make_symbols(**config.characters)
+    #     else:
+    #         from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
+    #         config.characters = CharactersConfig(**parse_symbols())
+    # if config.use_phonemes:
+    #     config.characters = Graphemes()
+    #     model_characters = phonemes if config.use_phonemes else symbols
+    #     num_chars = len(model_characters) + getattr(config, "add_blank", False)
+    #     return model_characters, config, num_chars

    def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
        return get_speaker_manager(config, restore_path, data, out_path)
@@ -247,8 +262,6 @@ class BaseTTS(BaseModel):

        if is_eval and not config.run_eval:
            loader = None
        else:
-            ap = assets["audio_processor"]
-
            # setup multi-speaker attributes
            if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
                if hasattr(config, "model_args"):
@@ -279,28 +292,21 @@ class BaseTTS(BaseModel):

            # init dataloader
            dataset = TTSDataset(
                outputs_per_step=config.r if "r" in config else 1,
-                text_cleaner=config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
                compute_f0=config.get("compute_f0", False),
                f0_cache_path=config.get("f0_cache_path", None),
                meta_data=data_items,
-                ap=ap,
-                characters=config.characters,
-                custom_symbols=custom_symbols,
-                add_blank=config["add_blank"],
+                ap=self.ap,
                return_wav=config.return_wav if "return_wav" in config else False,
                batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                phoneme_cache_path=config.phoneme_cache_path,
-                use_phonemes=config.use_phonemes,
-                phoneme_language=config.phoneme_language,
-                enable_eos_bos=config.enable_eos_bos_chars,
                use_noise_augment=False if is_eval else config.use_noise_augment,
                verbose=verbose,
                speaker_id_mapping=speaker_id_mapping,
-                d_vector_mapping=d_vector_mapping,
                language_id_mapping=language_id_mapping,
+                d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
+                tokenizer=self.tokenizer,
            )

            # pre-compute phonemes
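The net effect of this hunk: the per-field text options (`text_cleaner`, `characters`, `custom_symbols`, `add_blank`, `use_phonemes`, `phoneme_language`, `enable_eos_bos`) no longer go to the dataset, which receives the tokenizer instead. A condensed sketch of the new call shape (most arguments elided):

dataset = TTSDataset(
    outputs_per_step=config.r if "r" in config else 1,
    meta_data=data_items,
    ap=self.ap,
    tokenizer=self.tokenizer,  # handles text cleaning and token-ID conversion
    # ... remaining arguments as in the hunk above ...
)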
@@ -332,7 +338,7 @@ class BaseTTS(BaseModel):

        if config.compute_f0 and rank in [None, 0]:
            if not os.path.exists(config.f0_cache_path):
                dataset.pitch_extractor.compute_pitch(
-                    ap, config.get("f0_cache_path", None), config.num_loader_workers
+                    self.ap, config.get("f0_cache_path", None), config.num_loader_workers
                )

        # halt DDP processes for the main process to finish computing the F0 cache
@@ -404,6 +410,7 @@ class BaseTTS(BaseModel):

            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
        ap = assets["audio_processor"]
+        tokenizer = assets["tokenizer"]
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
@@ -416,6 +423,7 @@ class BaseTTS(BaseModel):

                self.config,
                "cuda" in str(next(self.parameters()).device),
                ap,
+                tokenizer,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
@@ -46,11 +46,9 @@ class GlowTTS(BaseTTS):

    """

-    def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
+    def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):

-        super().__init__(config)
-
-        self.speaker_manager = speaker_manager
+        super().__init__(config, ap, tokenizer, speaker_manager)

        # pass all config fields to `self`
        # for fewer code changes
@@ -58,7 +56,7 @@ class GlowTTS(BaseTTS):

        for key in config:
            setattr(self, key, config[key])

-        _, self.config, self.num_chars = self.get_characters(config)
+        self.num_chars = self.tokenizer.characters.num_chars
        self.decoder_output_dim = config.out_channels

        # init multi-speaker layers if necessary
@@ -448,7 +446,6 @@ class GlowTTS(BaseTTS):

        Returns:
            Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
        """
-        ap = assets["audio_processor"]
        print(" | > Synthesizing test sentences.")
        test_audios = {}
        test_figures = {}
@@ -463,7 +460,8 @@ class GlowTTS(BaseTTS):

                sen,
                self.config,
                "cuda" in str(next(self.parameters()).device),
-                ap,
+                self.ap,
+                self.tokenizer,
                speaker_id=aux_inputs["speaker_id"],
                d_vector=aux_inputs["d_vector"],
                style_wav=aux_inputs["style_wav"],
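Taken together, GlowTTS no longer derives its character set from the config: it is constructed with an audio processor and a tokenizer and reads `num_chars` from `tokenizer.characters`. A minimal sketch of the new construction flow, mirroring the recipe below (a default `GlowTTSConfig()` stands in for a real training config):

from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

config = GlowTTSConfig()
ap = AudioProcessor.init_from_config(config)
tokenizer = TTSTokenizer.init_from_config(config)
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)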
@@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig

from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
+from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor

# we use the same path as this script as our training folder.
@@ -47,7 +48,11 @@ config = GlowTTSConfig(

# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves the dataloader and the training loggers.
-ap = AudioProcessor(**config.audio.to_dict())
+ap = AudioProcessor.init_from_config(config)
+
+# INITIALIZE THE TOKENIZER
+# Tokenizer is used to convert text to sequences of token IDs.
+tokenizer = TTSTokenizer.init_from_config(config)

# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
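A small sketch of what the tokenizer provides to the model and dataloader; the `text_to_ids` method name is an assumption about the TTSTokenizer API, as it does not appear in this diff:

token_ids = tokenizer.text_to_ids("Hello world!")  # text -> list of token IDs
num_chars = tokenizer.characters.num_chars         # vocabulary size the model embeds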
@@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)

# Models take a config object, an audio processor, a tokenizer, and a speaker manager as input.
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
# Speaker manager is used by multi-speaker models.
-model = GlowTTS(config, speaker_manager=None)
+model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

# INITIALIZE THE TRAINER
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
@@ -71,8 +76,7 @@ trainer = Trainer(

    output_path,
    model=model,
    train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},  # assets are objects used by the models but not class members.
+    eval_samples=eval_samples
)

# AND... 3,2,1... 🚀
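The rocket comment marks the launch point; recipes like this one typically end with a fit call (not shown in this hunk):

trainer.fit()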