Doc update (#889)

* Link source files from the docs

* Update glowTTS recipes for docs

* Add dataset downloaders
pull/888/head
Eren Gölge 2021-10-26 17:41:33 +02:00 committed by GitHub
parent 0cac3f330a
commit 035ed432bc
8 changed files with 295 additions and 266 deletions


@@ -36,7 +36,9 @@ def split_dataset(items):
return items[:eval_split_size], items[eval_split_size:]
def load_tts_samples(datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable=None) -> Tuple[List[List], List[List]]:
def load_tts_samples(
datasets: Union[List[Dict], Dict], eval_split=True, formatter: Callable = None
) -> Tuple[List[List], List[List]]:
"""Parse the dataset from the datasets config, load the samples as a List and load the attention alignments if provided.
If `formatter` is not None, apply the formatter to the samples else pick the formatter from the available ones based
on the dataset name.

TTS/utils/download.py Normal file

@@ -0,0 +1,185 @@
# Adapted from https://github.com/pytorch/audio/
import hashlib
import logging
import os
import tarfile
import urllib
import urllib.request
import zipfile
from typing import Any, Iterable, List, Optional
from torch.utils.model_zoo import tqdm
def stream_url(
url: str, start_byte: Optional[int] = None, block_size: int = 32 * 1024, progress_bar: bool = True
) -> Iterable:
"""Stream url by chunk
Args:
url (str): Url.
start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
"""
# If we already have the whole file, there is no need to download it again
req = urllib.request.Request(url, method="HEAD")
with urllib.request.urlopen(req) as response:
url_size = int(response.info().get("Content-Length", -1))
if url_size == start_byte:
return
req = urllib.request.Request(url)
if start_byte:
req.headers["Range"] = "bytes={}-".format(start_byte)
with urllib.request.urlopen(req) as upointer, tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=url_size,
disable=not progress_bar,
) as pbar:
num_bytes = 0
while True:
chunk = upointer.read(block_size)
if not chunk:
break
yield chunk
num_bytes += len(chunk)
pbar.update(len(chunk))
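# --- Editor's usage sketch (not part of the module) ---------------------
# `stream_url` is a generator, so a caller typically writes the yielded
# chunks straight to disk. The URL and filename below are placeholders.
#
#   with open("archive.tar.gz", "wb") as fout:
#       for chunk in stream_url("https://example.com/archive.tar.gz"):
#           fout.write(chunk)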
def download_url(
url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False,
) -> None:
"""Download file to disk.
Args:
url (str): Url.
download_folder (str): Folder to download file.
filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
(Default: ``None``).
hash_value (str or None, optional): Hash for url (Default: ``None``).
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
resume (bool, optional): Enable resuming download (Default: ``False``).
"""
req = urllib.request.Request(url, method="HEAD")
req_info = urllib.request.urlopen(req).info() # pylint: disable=consider-using-with
# Detect filename
filename = filename or req_info.get_filename() or os.path.basename(url)
filepath = os.path.join(download_folder, filename)
if resume and os.path.exists(filepath):
mode = "ab"
local_size: Optional[int] = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath):
raise RuntimeError("{} already exists. Delete the file manually and retry.".format(filepath))
else:
mode = "wb"
local_size = None
if hash_value and local_size == int(req_info.get("Content-Length", -1)):
with open(filepath, "rb") as file_obj:
if validate_file(file_obj, hash_value, hash_type):
return
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
with open(filepath, mode) as fpointer:
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
fpointer.write(chunk)
with open(filepath, "rb") as file_obj:
if hash_value and not validate_file(file_obj, hash_value, hash_type):
raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(filepath))
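# --- Editor's usage sketch (not part of the module) ---------------------
# A resumable, checksum-verified download; the URL, folder, and hash are
# placeholders. With `resume=True`, a partial file is continued via an HTTP
# Range request instead of raising on an existing path.
#
#   download_url(
#       "https://example.com/archive.tar.gz",
#       download_folder="/tmp/data",
#       hash_value="<expected-sha256-hex>",
#       hash_type="sha256",
#       resume=True,
#   )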
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
"""Validate a given file object with its hash.
Args:
file_obj: File object to read from.
hash_value (str): Hash for url.
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
Returns:
bool: True if the hash of the file matches ``hash_value``, else False.
"""
if hash_type == "sha256":
hash_func = hashlib.sha256()
elif hash_type == "md5":
hash_func = hashlib.md5()
else:
raise ValueError
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024 ** 2)
if not chunk:
break
hash_func.update(chunk)
return hash_func.hexdigest() == hash_value
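# --- Editor's usage sketch (not part of the module) ---------------------
# Re-checking an already-downloaded file; the path and hash are placeholders.
#
#   with open("/tmp/data/archive.tar.gz", "rb") as fobj:
#       if not validate_file(fobj, "<expected-sha256-hex>", "sha256"):
#           raise RuntimeError("checksum mismatch")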
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
"""Extract archive.
Args:
from_path (str): the path of the archive.
to_path (str or None, optional): the root path of the extracted files (directory of from_path)
(Default: ``None``)
overwrite (bool, optional): overwrite existing files (Default: ``False``)
Returns:
list: List of paths to extracted files even if not overwritten.
"""
if to_path is None:
to_path = os.path.dirname(from_path)
try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file %s.", from_path)
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
if not overwrite:
continue
tar.extract(file_, to_path)
return files
except tarfile.ReadError:
pass
try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file %s.", from_path)
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("%s already extracted.", file_path)
if not overwrite:
continue
zfile.extract(file_, to_path)
return files
except zipfile.BadZipFile:
pass
raise NotImplementedError(" > [!] only supports tar.gz, tgz, and zip archives.")

TTS/utils/downloaders.py Normal file

@@ -0,0 +1,33 @@
import os
from TTS.utils.download import download_url, extract_archive
def download_ljspeech(path: str):
"""Download and extract LJSpeech dataset
Args:
path (str): path to the directory where the dataset will be stored.
"""
os.makedirs(path, exist_ok=True)
url = "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
def download_vctk(path: str):
"""Download and extract VCTK dataset
Args:
path (str): path to the directory where the dataset will be stored.
"""
os.makedirs(path, exist_ok=True)
url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
download_url(url, path)
basename = os.path.basename(url)
archive = os.path.join(path, basename)
print(" > Extracting archive file...")
extract_archive(archive)
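
As a quick smoke test, the two helpers can be called directly from Python; a minimal sketch, assuming the target directories are writable (the paths are placeholders):

```python
# Minimal usage sketch for the new downloaders; paths are placeholders.
from TTS.utils.downloaders import download_ljspeech, download_vctk

download_ljspeech("../recipes/ljspeech/")  # downloads and extracts LJSpeech-1.1
download_vctk("../recipes/vctk/")          # downloads and extracts VCTK-Corpus-0.92
```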


@@ -39,20 +39,6 @@ master_doc = "index"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
source_suffix = [".rst", ".md"]
# extensions
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
@@ -68,6 +54,17 @@ extensions = [
"sphinx_inline_tabs",
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'TODO/*']
source_suffix = [".rst", ".md"]
myst_enable_extensions = ['linkify',]
# 'sphinxcontrib.katex',


@@ -15,62 +15,7 @@
`Nervous Beginners`.
A recipe for `GlowTTS` using the `LJSpeech` dataset looks like below. Let's be creative and call this `train_glowtts.py`.
```python
# train_glowtts.py
import os
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts.configs.shared_config import BaseDatasetConfig
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.audio import AudioProcessor
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
config = GlowTTSConfig(
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
print_step=25,
print_eval=False,
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# init model
model = GlowTTS(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()
```{literalinclude} ../../recipes/ljspeech/glow_tts/train_glowtts.py
```
You need to change fields of the `BaseDatasetConfig` to match your dataset and then update `GlowTTSConfig`
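
For instance, pointing the recipe at your own single-speaker dataset is mostly a matter of overriding the dataset fields; a minimal sketch, assuming an LJSpeech-style `metadata.csv` (the path is a placeholder):

```python
from TTS.tts.configs.shared_configs import BaseDatasetConfig

dataset_config = BaseDatasetConfig(
    name="ljspeech",                 # formatter used to parse the metadata file
    meta_file_train="metadata.csv",  # transcript file of your dataset
    path="/data/my_dataset/",        # placeholder dataset root
)
```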
@@ -162,7 +107,7 @@
$ tensorboard --logdir=<path to your training directory>
```
6. Monitor the training process.
6. Monitor the training progress.
You can monitor the training progress on the terminal and on Tensorboard. Tensorboard also provides certain figures and sample outputs.
@@ -197,68 +142,5 @@ d-vectors. For using d-vectors, you first need to compute the d-vectors using th
The same Glow-TTS model above can be trained on a multi-speaker VCTK dataset with the script below.
```python
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs
from TTS.tts import BaseDatasetConfig, GlowTTSConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
# define dataset config for VCTK
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
# init audio processing config
audio_config = BaseAudioConfig(sample_rate=22050, do_trim_silence=True, trim_db=23.0)
# init training config
config = GlowTTSConfig(
batch_size=64,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="phoneme_cleaners",
use_phonemes=True,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
print_step=25,
print_eval=False,
mixed_precision=True,
output_path=output_path,
datasets=[dataset_config],
use_speaker_embedding=True,
)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# ONLY FOR MULTI-SPEAKER: init speaker manager for multi-speaker training
speaker_manager = SpeakerManager()
speaker_manager.set_speaker_ids_from_data(train_samples + eval_samples)
config.num_speakers = speaker_manager.num_speakers
# init model
model = GlowTTS(config, speaker_manager)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()
```
```{literalinclude} ../../recipes/vctk/glow_tts/train_glow_tts.py
```


@@ -18,80 +18,21 @@ $ pip install -e .
## Training a `tts` Model
A breakdown of a simple script training a GlowTTS model on LJspeech dataset. See the comments for the explanation of
each line.
A breakdown of a simple script that trains a GlowTTS model on the LJSpeech dataset. See the comments for more details.
### Pure Python Way
0. Download your dataset.
In this example, we download and use the LJSpeech dataset. Set the download directory based on your preferences.
```bash
$ python -c 'from TTS.utils.downloaders import download_ljspeech; download_ljspeech("../recipes/ljspeech/");'
```
1. Define `train.py`.
```python
import os
# GlowTTSConfig: all model related values for training, validating and testing.
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
# BaseDatasetConfig: defines name, formatter and path of the dataset.
from TTS.tts.configs.shared_config import BaseDatasetConfig
# init_training: Initialize and setup the training environment.
# Trainer: Where the ✨️ happens.
# TrainingArgs: Defines the set of arguments of the Trainer.
from TTS.trainer import init_training, Trainer, TrainingArgs
# we use the same path as this script as our training folder.
output_path = os.path.dirname(os.path.abspath(__file__))
# set LJSpeech as our target dataset and define its path so that the Trainer knows what data formatter it needs.
dataset_config = BaseDatasetConfig(name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/"))
# Configure the model. Every config class inherits the BaseTTSConfig to have all the fields defined for the Trainer.
config = GlowTTSConfig(
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=-1,
epochs=1000,
text_cleaner="english_cleaners",
use_phonemes=False,
phoneme_language="en-us",
phoneme_cache_path=os.path.join(output_path, "phoneme_cache"),
print_step=25,
print_eval=True,
mixed_precision=False,
output_path=output_path,
datasets=[dataset_config]
)
# initialize the audio processor used for feature extraction and audio I/O.
# It is mainly used by the dataloader and the training loggers.
ap = AudioProcessor(**config.audio.to_dict())
# load a list of training samples
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
# initialize the model
# Models only takes the config object as input.
model = GlowTTS(config)
# Initiate the Trainer.
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
# And kick it 🚀
trainer.fit()
```{literalinclude} ../../recipes/ljspeech/glow_tts/train_glowtts.py
```
2. Run the script.
@@ -154,58 +95,7 @@ We still support running training from CLI like in the old days. The same traini
## Training a `vocoder` Model
```python
import os
from TTS.trainer import Trainer, TrainingArgs
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs import HifiganConfig
from TTS.vocoder.datasets.preprocess import load_wav_data
from TTS.vocoder.models.gan import GAN
output_path = os.path.dirname(os.path.abspath(__file__))
config = HifiganConfig(
batch_size=32,
eval_batch_size=16,
num_loader_workers=4,
num_eval_loader_workers=4,
run_eval=True,
test_delay_epochs=5,
epochs=1000,
seq_len=8192,
pad_short=2000,
use_noise_augment=True,
eval_split_size=10,
print_step=25,
print_eval=False,
mixed_precision=False,
lr_gen=1e-4,
lr_disc=1e-4,
data_path=os.path.join(output_path, "../LJSpeech-1.1/wavs/"),
output_path=output_path,
)
# init audio processor
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
eval_samples, train_samples = load_wav_data(config.data_path, config.eval_split_size)
# init model
model = GAN(config)
# init the trainer and 🚀
trainer = Trainer(
TrainingArgs(),
config,
output_path,
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
)
trainer.fit()
```{literalinclude} ../../recipes/ljspeech/hifigan/train_hifigan.py
```
❗️ Note that you can also run ```train_vocoder.py``` the same way as the ```tts``` models above.


@@ -1,16 +1,30 @@
import os
# Trainer: Where the ✨️ happens.
# TrainingArgs: Defines the set of arguments of the Trainer.
from TTS.trainer import Trainer, TrainingArgs
# GlowTTSConfig: all model related values for training, validating and testing.
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
# BaseDatasetConfig: defines name, formatter and path of the dataset.
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.utils.audio import AudioProcessor
# we use the same path as this script as our training folder.
output_path = os.path.dirname(os.path.abspath(__file__))
# DEFINE DATASET CONFIG
# Set LJSpeech as our target dataset and define its path.
# You can also use a simple Dict to define the dataset and pass it to your custom formatter.
dataset_config = BaseDatasetConfig(
name="ljspeech", meta_file_train="metadata.csv", path=os.path.join(output_path, "../LJSpeech-1.1/")
)
# INITIALIZE THE TRAINING CONFIGURATION
# Configure the model. Every config class inherits the BaseTTSConfig.
config = GlowTTSConfig(
batch_size=32,
eval_batch_size=16,
@@ -30,16 +44,27 @@ config = GlowTTSConfig(
datasets=[dataset_config],
)
# init audio processor
# INITIALIZE THE AUDIO PROCESSOR
# Audio processor is used for feature extraction and audio I/O.
# It mainly serves the dataloader and the training loggers.
ap = AudioProcessor(**config.audio.to_dict())
# load training samples
# LOAD DATA SAMPLES
# Each sample is a list of ```[text, audio_file_path, speaker_name]```
# You can define your custom sample loader returning the list of samples.
# Or define your custom formatter and pass it to the `load_tts_samples`.
# Check `TTS.tts.datasets.load_tts_samples` for more details.
train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True)
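# Editor's sketch (not part of the recipe): a hypothetical custom formatter.
# It should return one `[text, audio_file_path, speaker_name]` entry per sample.
#
#   def my_formatter(root_path, meta_file, **kwargs):
#       items = []
#       with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as ttf:
#           for line in ttf:
#               file_id, text = line.strip().split("|")[:2]
#               wav_file = os.path.join(root_path, "wavs", file_id + ".wav")
#               items.append([text, wav_file, "my_speaker"])
#       return items
#
#   train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=my_formatter)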
# init model
model = GlowTTS(config)
# INITIALIZE THE MODEL
# Models take a config object and a speaker manager as input.
# Config defines the details of the model like the number of layers, the size of the embedding, etc.
# Speaker manager is used by multi-speaker models.
model = GlowTTS(config, speaker_manager=None)
# init the trainer and 🚀
# INITIALIZE THE TRAINER
# Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
# distributed training, etc.
trainer = Trainer(
TrainingArgs(),
config,
@@ -47,6 +72,8 @@ trainer = Trainer(
model=model,
train_samples=train_samples,
eval_samples=eval_samples,
training_assets={"audio_processor": ap},
training_assets={"audio_processor": ap}, # assets are objetcs used by the models but not class members.
)
# AND... 3,2,1... 🚀
trainer.fit()


@@ -9,11 +9,24 @@ from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.speakers import SpeakerManager
from TTS.utils.audio import AudioProcessor
# set experiment paths
output_path = os.path.dirname(os.path.abspath(__file__))
dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=os.path.join(output_path, "../VCTK/"))
dataset_path = os.path.join(output_path, "../VCTK/")
audio_config = BaseAudioConfig(sample_rate=22050, do_trim_silence=True, trim_db=23.0)
# download the dataset if not downloaded
if not os.path.exists(dataset_path):
from TTS.utils.downloaders import download_vctk
download_vctk(dataset_path)
# define dataset config
dataset_config = BaseDatasetConfig(name="vctk", meta_file_train="", path=dataset_path)
# define audio config
# ❗ resample the dataset externally using `TTS/bin/resample.py` and set `resample=False` for faster training
audio_config = BaseAudioConfig(sample_rate=22050, resample=True, do_trim_silence=True, trim_db=23.0)
# define model config
config = GlowTTSConfig(
batch_size=64,
eval_batch_size=16,