mirror of https://github.com/coqui-ai/TTS.git
Doc update (#889)
* Link source files from the docs
* Update glowTTS recipes for docs
* Add dataset downloaders
parent 0cac3f330a
commit 035ed432bc
@@ -36,7 +36,9 @@ def split_dataset(items):
     return items[:eval_split_size], items[eval_split_size:]
 
 
-def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable=None) -> Tuple[List[List], List[List]]:
+def load_tts_samples(
+    datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None
+) -> Tuple[List[List], List[List]]:
     """Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
     If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
     on the dataset name.
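The docstring above describes the `formatter` fallback; below is a minimal sketch of how a custom formatter plugs in. The function name, metadata layout, and speaker label are illustrative assumptions, not part of this commit:

```python
import os

# Hypothetical formatter for a metadata file with `wav_path|text` rows.
def my_formatter(root_path, meta_file):
    items = []
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
        for line in f:
            wav_file, text = line.strip().split("|")
            # each sample is [text, audio_file_path, speaker_name]
            items.append([text, os.path.join(root_path, wav_file), "my_speaker"])
    return items

# With formatter=None, a built-in formatter is picked by the dataset name instead.
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=my_formatter)
```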
@@ -0,0 +1,185 @@
# Adapted from https://github.com/pytorch/audio/

import hashlib
import logging
import os
import tarfile
import urllib
import urllib.request
import zipfile
from typing import Any, Iterable, List, Optional

from torch.utils.model_zoo import tqdm


def stream_url(
    url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
    """Stream url by chunk

    Args:
        url (str): Url.
        start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
    """

    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    with urllib.request.urlopen(req) as response:
        url_size = int(response.info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=url_size,
        disable=not progress_bar,
    ) as pbar:

        num_bytes = 0
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            num_bytes += len(chunk)
            pbar.update(len(chunk))


def download_url(
    url: str,
    download_folder: str,
    filename: Optional[str] = None,
    hash_value: Optional[str] = None,
    hash_type: str = "sha256",
    progress_bar: bool = True,
    resume: bool = False,
) -> None:
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
            (Default: ``None``).
        hash_value (str or None, optional): Hash for url (Default: ``None``).
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
        resume (bool, optional): Enable resuming download (Default: ``False``).
    """

    req = urllib.request.Request(url, method="HEAD")
    req_info = urllib.request.urlopen(req).info()  # pylint: disable=consider-using-with

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)
    if resume and os.path.exists(filepath):
        mode = "ab"
        local_size: Optional[int] = os.path.getsize(filepath)

    elif not resume and os.path.exists(filepath):
        raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
    else:
        mode = "wb"
        local_size = None

    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    with open(filepath, "rb") as file_obj:
        if hash_value and not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))


def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Hash for url.
        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).

    Returns:
        bool: return True if it's a valid file, else False.
    """

    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError

    while True:
        # Read by chunk to avoid filling memory
        chunk = file_obj.read(1024 ** 2)
        if not chunk:
            break
        hash_func.update(chunk)

    return hash_func.hexdigest() == hash_value


def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
    """Extract archive.
    Args:
        from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
            (Default: ``None``)
        overwrite (bool, optional): overwrite existing files (Default: ``False``)

    Returns:
        list: List of paths to extracted files even if not overwritten.
    """

    if to_path is None:
        to_path = os.path.dirname(from_path)

    try:
        with tarfile.open(from_path, "r") as tar:
            logging.info("Opened tar file %s.", from_path)
            files = []
            for file_ in tar:  # type: Any
                file_path = os.path.join(to_path, file_.name)
                if file_.isfile():
                    files.append(file_path)
                    if os.path.exists(file_path):
                        logging.info("%s already extracted.", file_path)
                        if not overwrite:
                            continue
                tar.extract(file_, to_path)
            return files
    except tarfile.ReadError:
        pass

    try:
        with zipfile.ZipFile(from_path, "r") as zfile:
            logging.info("Opened zip file %s.", from_path)
            files = zfile.namelist()
            for file_ in files:
                file_path = os.path.join(to_path, file_)
                if os.path.exists(file_path):
                    logging.info("%s already extracted.", file_path)
                    if not overwrite:
                        continue
                zfile.extract(file_, to_path)
            return files
    except zipfile.BadZipFile:
        pass

    raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip archives.")
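A quick usage sketch for the helpers above. The URL and paths are placeholders; note that `download_url` writes into `download_folder` directly, so the folder must already exist:

```python
import os

from TTS.utils.download import download_url, extract_archive

download_folder = "/tmp/tts_downloads"  # placeholder path
os.makedirs(download_folder, exist_ok=True)  # download_url() does not create the folder itself

download_url(
    "https://example.com/dataset.tar.gz",  # placeholder URL
    download_folder,
    hash_value=None,  # set to a known sha256 hex digest to enable verification
    resume=True,      # append to a partial file instead of raising on re-runs
)

# extracts next to the archive and returns the extracted file paths
files = extract_archive(os.path.join(download_folder, "dataset.tar.gz"))
```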
@@ -0,0 +1,33 @@
import os

from TTS.utils.download import download_url, extract_archive


def download_ljspeech(path: str):
    """Download and extract LJSpeech dataset

    Args:
        path (str): path to the directory where the dataset will be stored.
    """
    os.makedirs(path, exist_ok=True)
    url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
    download_url(url, path)
    basename = os.path.basename(url)
    archive = os.path.join(path, basename)
    print(" > Extracting archive file...")
    extract_archive(archive)


def download_vctk(path: str):
    """Download and extract VCTK dataset

    Args:
        path (str): path to the directory where the dataset will be stored.
    """
    os.makedirs(path, exist_ok=True)
    url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
    download_url(url, path)
    basename = os.path.basename(url)
    archive = os.path.join(path, basename)
    print(" > Extracting archive file...")
    extract_archive(archive)
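A downloader for another corpus can follow the same download-then-extract pattern; a hypothetical sketch with a placeholder URL (not a real dataset endpoint):

```python
import os

from TTS.utils.download import download_url, extract_archive


def download_my_dataset(path: str):
    """Hypothetical downloader following the same pattern as the functions above."""
    os.makedirs(path, exist_ok=True)
    url = "https://example.com/my_dataset.zip"  # placeholder URL
    download_url(url, path)
    archive = os.path.join(path, os.path.basename(url))
    print(" > Extracting archive file...")
    extract_archive(archive)
```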
@@ -39,20 +39,6 @@ master_doc = "index"
 # Add any Sphinx extension module names here, as strings. They can be
 # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
 # ones.
-extensions = [
-]
-
-# Add any paths that contain templates here, relative to this directory.
-templates_path = ['_templates']
-
-# List of patterns, relative to source directory, that match files and
-# directories to ignore when looking for source files.
-# This pattern also affects html_static_path and html_extra_path.
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
-
-source_suffix = [".rst", ".md"]
-
-# extensions
 extensions = [
     'sphinx.ext.autodoc',
     'sphinx.ext.autosummary',
@@ -68,6 +54,17 @@ extensions = [
     "sphinx_inline_tabs",
 ]
 
 
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
+
+source_suffix = [".rst", ".md"]
+
+myst_enable_extensions = ['linkify',]
 
 # 'sphinxcontrib.katex',
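Note: the `linkify` entry in `myst_enable_extensions` requires the `linkify-it-py` package at docs build time; if the build complains about it, the package can be installed directly:

```bash
$ pip install linkify-it-py
```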
@@ -15,62 +15,7 @@
 `Nervous Beginners`.
 A recipe for `GlowTTS` using `LJSpeech` dataset looks like below. Let's be creative and call this `train_glowtts.py`.
 
-```python
-# train_glowtts.py
-
-import os
-
-from TTS.trainer import Trainer, TrainingArgs
-from TTS.tts.configs.shared_config import BaseDatasetConfig
-from TTS.tts.configs.glow_tts_config import GlowTTSConfig
-from TTS.tts.datasets import load_tts_samples
-from TTS.tts.models.glow_tts import GlowTTS
-from TTS.utils.audio import AudioProcessor
-
-output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(
-    name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
-)
-config = GlowTTSConfig(
-    batch_size=32,
-    eval_batch_size=16,
-    num_loader_workers=4,
-    num_eval_loader_workers=4,
-    run_eval=True,
-    test_delay_epochs=-1,
-    epochs=1000,
-    text_cleaner="phoneme_cleaners",
-    use_phonemes=True,
-    phoneme_language="en-us",
-    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
-    print_step=25,
-    print_eval=False,
-    mixed_precision=True,
-    output_path=output_path,
-    datasets=[dataset_config],
-)
-
-# init audio processor
-ap = AudioProcessor(**config.audio.to_dict())
-
-# load training samples
-train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
-
-# init model
-model = GlowTTS(config)
-
-# init the trainer and 🚀
-trainer = Trainer(
-    TrainingArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},
-)
-trainer.fit()
-
+```{literalinclude} ../../recipes/ljspeech/glow_tts/train_glowtts.py
 ```
 
 You need to change fields of the `BaseDatasetConfig` to match your dataset and then update `GlowTTSConfig`
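For a custom dataset this typically means something like the following (the path is a placeholder; `name` selects the formatter used to parse your metadata file):

```python
dataset_config = BaseDatasetConfig(
    name="ljspeech",                 # formatter that matches your metadata layout
    meta_file_train="metadata.csv",  # relative to `path`
    path="/path/to/your/dataset/",   # placeholder path
)
```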
@@ -162,7 +107,7 @@
 $ tensorboard --logdir=<path to your training directory>
 ```
 
-6. Monitor the training process.
+6. Monitor the training progress.
 
 On the terminal and Tensorboard, you can monitor the progress of your model. Also Tensorboard provides certain figures and sample outputs.
 
@@ -197,68 +142,5 @@ d-vectors. For using d-vectors, you first need to compute the d-vectors using th
 
 The same Glow-TTS model above can be trained on a multi-speaker VCTK dataset with the script below.
 
-```python
-import os
-
-from TTS.config.shared_configs import BaseAudioConfig
-from TTS.trainer import Trainer, TrainingArgs
-from TTS.tts import BaseDatasetConfig, GlowTTSConfig
-from TTS.tts.datasets import load_tts_samples
-from TTS.tts.glow_tts import GlowTTS
-from TTS.tts.utils.speakers import SpeakerManager
-from TTS.utils.audio import AudioProcessor
-
-# define dataset config for VCTK
-output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
-
-# init audio processing config
-audio_config = BaseAudioConfig(sample_rate=22050, do_trim_silence=True, trim_db=23.0)
-
-# init training config
-config = GlowTTSConfig(
-    batch_size=64,
-    eval_batch_size=16,
-    num_loader_workers=4,
-    num_eval_loader_workers=4,
-    run_eval=True,
-    test_delay_epochs=-1,
-    epochs=1000,
-    text_cleaner="phoneme_cleaners",
-    use_phonemes=True,
-    phoneme_language="en-us",
-    phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
-    print_step=25,
-    print_eval=False,
-    mixed_precision=True,
-    output_path=output_path,
-    datasets=[dataset_config],
-    use_speaker_embedding=True,
-)
-
-# init audio processor
-ap = AudioProcessor(**config.audio.to_dict())
-
-# load training samples
-train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
-
-# ONLY FOR MULTI-SPEAKER: init speaker manager for multi-speaker training
-speaker_manager = SpeakerManager()
-speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
-config.num_speakers = speaker_manager.num_speakers
-
-# init model
-model = GlowTTS(config, speaker_manager)
-
-# init the trainer and 🚀
-trainer = Trainer(
-    TrainingArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},
-)
-trainer.fit()
-```
+```{literalinclude} ../../recipes/vctk/glow_tts/train_glow_tts.py
+```
@@ -18,80 +18,21 @@ $ pip install -e .
 
 ## Training a `tts` Model
 
-A breakdown of a simple script training a GlowTTS model on LJspeech dataset. See the comments for the explanation of
-each line.
+A breakdown of a simple script that trains a GlowTTS model on the LJspeech dataset. See the comments for more details.
 
 ### Pure Python Way
 
+0. Download your dataset.
+
+    In this example, we download and use the LJSpeech dataset. Set the download directory based on your preferences.
+
+    ```bash
+    $ python -c 'from TTS.utils.downloaders import download_ljspeech; download_ljspeech("../recipes/ljspeech/");'
+    ```
+
 1. Define `train.py`.
 
-    ```python
-    import os
-
-    # GlowTTSConfig: all model related values for training, validating and testing.
-    from TTS.tts.configs.glow_tts_config import GlowTTSConfig
-
-    # BaseDatasetConfig: defines name, formatter and path of the dataset.
-    from TTS.tts.configs.shared_config import BaseDatasetConfig
-
-    # init_training: Initialize and setup the training environment.
-    # Trainer: Where the ✨️ happens.
-    # TrainingArgs: Defines the set of arguments of the Trainer.
-    from TTS.trainer import init_training, Trainer, TrainingArgs
-
-    # we use the same path as this script as our training folder.
-    output_path = os.path.dirname(os.path.abspath(__file__))
-
-    # set LJSpeech as our target dataset and define its path so that the Trainer knows what data formatter it needs.
-    dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
-
-    # Configure the model. Every config class inherits the BaseTTSConfig to have all the fields defined for the Trainer.
-    config = GlowTTSConfig(
-        batch_size=32,
-        eval_batch_size=16,
-        num_loader_workers=4,
-        num_eval_loader_workers=4,
-        run_eval=True,
-        test_delay_epochs=-1,
-        epochs=1000,
-        text_cleaner="english_cleaners",
-        use_phonemes=False,
-        phoneme_language="en-us",
-        phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
-        print_step=25,
-        print_eval=True,
-        mixed_precision=False,
-        output_path=output_path,
-        datasets=[dataset_config]
-    )
-
-    # initialize the audio processor used for feature extraction and audio I/O.
-    # It is mainly used by the dataloader and the training loggers.
-    ap = AudioProcessor(**config.audio.to_dict())
-
-    # load a list of training samples
-    # Each sample is a list of ```[text, audio_file_path, speaker_name]```
-    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
-
-    # initialize the model
-    # Models only takes the config object as input.
-    model = GlowTTS(config)
-
-    # Initiate the Trainer.
-    # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
-    # distributed training, etc.
-    trainer = Trainer(
-        TrainingArgs(),
-        config,
-        output_path,
-        model=model,
-        train_samples=train_samples,
-        eval_samples=eval_samples,
-        training_assets={"audio_processor": ap},
-    )
-
-    # And kick it 🚀
-    trainer.fit()
+    ```{literalinclude} ../../recipes/ljspeech/glow_tts/train_glowtts.py
     ```
 
 2. Run the script.
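For reference, running the script is a plain Python invocation; pinning the GPU via `CUDA_VISIBLE_DEVICES` is the usual convention (a sketch, not text from this commit):

```bash
$ CUDA_VISIBLE_DEVICES="0" python train.py
```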
@@ -154,58 +95,7 @@ We still support running training from CLI like in the old days. The same traini
 
 ## Training a `vocoder` Model
 
-```python
-import os
-
-from TTS.trainer import Trainer, TrainingArgs
-from TTS.utils.audio import AudioProcessor
-from TTS.vocoder.configs import HifiganConfig
-from TTS.vocoder.datasets.preprocess import load_wav_data
-from TTS.vocoder.models.gan import GAN
-
-output_path = os.path.dirname(os.path.abspath(__file__))
-
-config = HifiganConfig(
-    batch_size=32,
-    eval_batch_size=16,
-    num_loader_workers=4,
-    num_eval_loader_workers=4,
-    run_eval=True,
-    test_delay_epochs=5,
-    epochs=1000,
-    seq_len=8192,
-    pad_short=2000,
-    use_noise_augment=True,
-    eval_split_size=10,
-    print_step=25,
-    print_eval=False,
-    mixed_precision=False,
-    lr_gen=1e-4,
-    lr_disc=1e-4,
-    data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
-    output_path=output_path,
-)
-
-# init audio processor
-ap = AudioProcessor(**config.audio.to_dict())
-
-# load training samples
-eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
-
-# init model
-model = GAN(config)
-
-# init the trainer and 🚀
-trainer = Trainer(
-    TrainingArgs(),
-    config,
-    output_path,
-    model=model,
-    train_samples=train_samples,
-    eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},
-)
-trainer.fit()
+```{literalinclude} ../../recipes/ljspeech/hifigan/train_hifigan.py
 ```
 
 ❗️ Note that you can also use ```train_vocoder.py``` in the same way as the ```tts``` models above.
 
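As a sketch of that CLI route (the `--config_path` flag is assumed to mirror the `tts` training scripts):

```bash
$ python TTS/bin/train_vocoder.py --config_path path/to/config.json
```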
@@ -1,16 +1,30 @@
 import os
 
+# Trainer: Where the ✨️ happens.
+# TrainingArgs: Defines the set of arguments of the Trainer.
 from TTS.trainer import Trainer, TrainingArgs
+
+# GlowTTSConfig: all model related values for training, validating and testing.
 from TTS.tts.configs.glow_tts_config import GlowTTSConfig
+
+# BaseDatasetConfig: defines name, formatter and path of the dataset.
 from TTS.tts.configs.shared_configs import BaseDatasetConfig
 from TTS.tts.datasets import load_tts_samples
 from TTS.tts.models.glow_tts import GlowTTS
 from TTS.utils.audio import AudioProcessor
 
+# we use the same path as this script as our training folder.
 output_path = os.path.dirname(os.path.abspath(__file__))
+
+# DEFINE DATASET CONFIG
+# Set LJSpeech as our target dataset and define its path.
+# You can also use a simple Dict to define the dataset and pass it to your custom formatter.
 dataset_config = BaseDatasetConfig(
     name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
 )
+
+# INITIALIZE THE TRAINING CONFIGURATION
+# Configure the model. Every config class inherits the BaseTTSConfig.
 config = GlowTTSConfig(
     batch_size=32,
     eval_batch_size=16,
@@ -30,16 +44,27 @@ config = GlowTTSConfig(
     datasets=[dataset_config],
 )
 
-# init audio processor
+# INITIALIZE THE AUDIO PROCESSOR
+# Audio processor is used for feature extraction and audio I/O.
+# It mainly serves the dataloader and the training loggers.
 ap = AudioProcessor(**config.audio.to_dict())
 
-# load training samples
+# LOAD DATA SAMPLES
+# Each sample is a list of ```[text, audio_file_path, speaker_name]```
+# You can define your custom sample loader returning the list of samples.
+# Or define your custom formatter and pass it to the `load_tts_samples`.
+# Check `TTS.tts.datasets.load_tts_samples` for more details.
 train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
 
-# init model
-model = GlowTTS(config)
+# INITIALIZE THE MODEL
+# Models take a config object and a speaker manager as input
+# Config defines the details of the model like the number of layers, the size of the embedding, etc.
+# Speaker manager is used by multi-speaker models.
+model = GlowTTS(config, speaker_manager=None)
 
-# init the trainer and 🚀
+# INITIALIZE THE TRAINER
+# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
+# distributed training, etc.
 trainer = Trainer(
     TrainingArgs(),
     config,
@@ -47,6 +72,8 @@ trainer = Trainer(
     model=model,
     train_samples=train_samples,
     eval_samples=eval_samples,
-    training_assets={"audio_processor": ap},
+    training_assets={"audio_processor": ap},  # assets are objects used by the models but not class members.
 )
+
+# AND... 3,2,1... 🚀
 trainer.fit()
@@ -9,11 +9,24 @@ from TTS.tts.models.glow_tts import GlowTTS
 from TTS.tts.utils.speakers import SpeakerManager
 from TTS.utils.audio import AudioProcessor
 
+# set experiment paths
 output_path = os.path.dirname(os.path.abspath(__file__))
-dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
+dataset_path = os.path.join(output_path, "../VCTK/")
 
-audio_config = BaseAudioConfig(sample_rate=22050, do_trim_silence=True, trim_db=23.0)
+# download the dataset if not downloaded
+if not os.path.exists(dataset_path):
+    from TTS.utils.downloaders import download_vctk
+
+    download_vctk(dataset_path)
+
+# define dataset config
+dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=dataset_path)
+
+# define audio config
+# ❗ resample the dataset externally using `TTS/bin/resample.py` and set `resample=False` for faster training
+audio_config = BaseAudioConfig(sample_rate=22050, resample=True, do_trim_silence=True, trim_db=23.0)
 
+# define model config
 config = GlowTTSConfig(
     batch_size=64,
     eval_batch_size=16,
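The external resampling step recommended in the comment above would look roughly like this (`TTS/bin/resample.py` flags assumed):

```bash
$ python TTS/bin/resample.py --input_dir ../VCTK/ --output_sr 22050
```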