Refactor GlowTTS model and recipe for TTSTokenizer

pull/1324/head
Eren Gölge 2021-11-16 13:36:35 +01:00
parent 5a9653978a
commit bd461ace33
4 changed files with 54 additions and 41 deletions

View File

@ -22,10 +22,13 @@ class BaseModel(nn.Module, ABC):
def __init__(self, config: Coqpit):
super().__init__()
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Set model arguments from the config. Override this."""
@staticmethod
def init_from_config(config: Coqpit):
"""Init the model from given config.
Override this depending on your model.
"""
pass
@abstractmethod

View File

@ -15,7 +15,7 @@ from TTS.tts.datasets.dataset import TTSDataset
from TTS.tts.utils.languages import LanguageManager, get_language_weighted_sampler
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager, get_speaker_weighted_sampler
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text import make_symbols
from TTS.tts.utils.text.symbols import Graphemes, make_symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
# pylint: skip-file
@ -34,8 +34,20 @@ class BaseTTS(BaseModel):
- 1D tensors `batch x 1`
"""
def __init__(self, config: Coqpit, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
super().__init__(config)
self.config = config
self.ap = ap
self.tokenizer = tokenizer
self.speaker_manager = speaker_manager
self._set_model_args(config)
def _set_model_args(self, config: Coqpit):
"""Setup model args based on the config type.
"""Setup model args based on the config type (`ModelConfig` or `ModelArgs`).
`ModelArgs` has all the fields reuqired to initialize the model architecture.
`ModelConfig` has all the fields required for training, inference and containes `ModelArgs`.
If the config is for training with a name like "*Config", then the model args are embeded in the
config.model_args
@ -44,8 +56,8 @@ class BaseTTS(BaseModel):
"""
# don't use isintance not to import recursively
if "Config" in config.__class__.__name__:
num_chars = self.config.model_args.num_chars if self.tokenizer is None else self.tokenizer.characters.num_chars
if "characters" in config:
_, self.config, num_chars = self.get_characters(config)
self.config.num_chars = num_chars
if hasattr(self.config, "model_args"):
config.model_args.num_chars = num_chars
@ -58,18 +70,21 @@ class BaseTTS(BaseModel):
else:
raise ValueError("config must be either a *Config or *Args")
@staticmethod
def get_characters(config: Coqpit) -> str:
# TODO: implement CharacterProcessor
if config.characters is not None:
symbols, phonemes = make_symbols(**config.characters)
else:
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
# @staticmethod
# def get_characters(config: Coqpit) -> str:
# # TODO: implement CharacterProcessor
# if config.characters is not None:
# symbols, phonemes = make_symbols(**config.characters)
# else:
# from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
config.characters = CharactersConfig(**parse_symbols())
model_characters = phonemes if config.use_phonemes else symbols
num_chars = len(model_characters) + getattr(config, "add_blank", False)
return model_characters, config, num_chars
# if config.use_phonemes:
# config.characters = Graphemes()
# model_characters = phonemes if config.use_phonemes else symbols
# num_chars = len(model_characters) + getattr(config, "add_blank", False)
# return model_characters, config, num_chars
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
return get_speaker_manager(config, restore_path, data, out_path)
@ -247,8 +262,6 @@ class BaseTTS(BaseModel):
if is_eval and not config.run_eval:
loader = None
else:
ap = assets["audio_processor"]
# setup multi-speaker attributes
if hasattr(self, "speaker_manager") and self.speaker_manager is not None:
if hasattr(config, "model_args"):
@ -279,28 +292,21 @@ class BaseTTS(BaseModel):
# init dataloader
dataset = TTSDataset(
outputs_per_step=config.r if "r" in config else 1,
text_cleaner=config.text_cleaner,
compute_linear_spec=config.model.lower() == "tacotron" or config.compute_linear_spec,
compute_f0=config.get("compute_f0", False),
f0_cache_path=config.get("f0_cache_path", None),
meta_data=data_items,
ap=ap,
characters=config.characters,
custom_symbols=custom_symbols,
add_blank=config["add_blank"],
ap=self.ap,
return_wav=config.return_wav if "return_wav" in config else False,
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
min_seq_len=config.min_seq_len,
max_seq_len=config.max_seq_len,
phoneme_cache_path=config.phoneme_cache_path,
use_phonemes=config.use_phonemes,
phoneme_language=config.phoneme_language,
enable_eos_bos=config.enable_eos_bos_chars,
use_noise_augment=False if is_eval else config.use_noise_augment,
verbose=verbose,
speaker_id_mapping=speaker_id_mapping,
d_vector_mapping=d_vector_mapping,
language_id_mapping=language_id_mapping,
d_vector_mapping=d_vector_mapping if config.use_d_vector_file else None,
tokenizer=self.tokenizer
)
# pre-compute phonemes
@ -332,7 +338,7 @@ class BaseTTS(BaseModel):
if config.compute_f0 and rank in [None, 0]:
if not os.path.exists(config.f0_cache_path):
dataset.pitch_extractor.compute_pitch(
ap, config.get("f0_cache_path", None), config.num_loader_workers
self.ap, config.get("f0_cache_path", None), config.num_loader_workers
)
# halt DDP processes for the main process to finish computing the F0 cache
@ -404,6 +410,7 @@ class BaseTTS(BaseModel):
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
tokenizer = assets["tokenizer"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@ -416,6 +423,7 @@ class BaseTTS(BaseModel):
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
tokenizer,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],

View File

@ -46,11 +46,9 @@ class GlowTTS(BaseTTS):
"""
def __init__(self, config: GlowTTSConfig, speaker_manager: SpeakerManager = None):
def __init__(self, config: GlowTTSConfig, ap: "AudioProcessor", tokenizer: "TTSTokenizer", speaker_manager: SpeakerManager = None):
super().__init__(config)
self.speaker_manager = speaker_manager
super().__init__(config, ap, tokenizer, speaker_manager)
# pass all config fields to `self`
# for fewer code change
@ -58,7 +56,7 @@ class GlowTTS(BaseTTS):
for key in config:
setattr(self, key, config[key])
_, self.config, self.num_chars = self.get_characters(config)
self.num_chars = self.tokenizer.characters.num_chars
self.decoder_output_dim = config.out_channels
# init multi-speaker layers if necessary
@ -448,7 +446,6 @@ class GlowTTS(BaseTTS):
Returns:
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
"""
ap = assets["audio_processor"]
print(" | > Synthesizing test sentences.")
test_audios = {}
test_figures = {}
@ -463,7 +460,8 @@ class GlowTTS(BaseTTS):
sen,
self.config,
"cuda" in str(next(self.parameters()).device),
ap,
self.ap,
self.tokenizer,
speaker_id=aux_inputs["speaker_id"],
d_vector=aux_inputs["d_vector"],
style_wav=aux_inputs["style_wav"],

View File

@ -11,6 +11,7 @@ from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
# we use the same path as this script as our training folder.
@ -47,7 +48,11 @@ config = GlowTTSConfig(
# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves to the dataloader and the training loggers.
ap = AudioProcessor(**config.audio.to_dict())
ap = AudioProcessor.init_from_config(config)
# INITIALIZE THE TOKENIZER
# Tokenizer is used to convert text to sequences of token IDs.
tokenizer = TTSTokenizer.init_from_config(config)
# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
@ -60,7 +65,7 @@ train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# Models take a config object and a speaker manager as input
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
# Speaker manager is used by multi-speaker models.
model = GlowTTS(config, speaker_manager=None)
model = GlowTTS(config, ap, tokenizer, speaker_manager=None)
# INITIALIZE THE TRAINER
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
@ -71,8 +76,7 @@ trainer = Trainer(
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap}, # assets are objetcs used by the models but not class members.
eval_samples=eval_samples
)
# AND... 3,2,1... 🚀