mirror of https://github.com/coqui-ai/TTS.git
Merge fix and eval split as argparse
commit 2e5baffa9c
@@ -18,8 +18,8 @@ jobs:
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.6, 3.7, 3.8]
        python-version: [3.6, 3.7, 3.8, 3.9]
        experimental: [false]
    steps:
      - uses: actions/checkout@v2
      - uses: actions/cache@v1
@@ -124,7 +124,9 @@ version.py
# jupyter dummy files
core

# files used internally for dev, test etc.
tests/outputs/*
tests/train_outputs/*
TODO.txt
.vscode/*
data/*
@@ -132,7 +134,22 @@ notebooks/data/*
TTS/tts/layers/glow_tts/monotonic_align/core.c
.vscode-upload.json
temp_build/*
recipes/*

# nohup logs
recipes/WIP/*
recipes/ljspeech/LJSpeech-1.1/*
events.out*
old_configs/*
model_importers/*
model_profiling/*
docs/source/TODO/*
docs/source/models/*
.noseids
.dccache
log.txt
umap.png
*.out
SocialMedia.txt
output.wav
tts_output.wav
deps.json
speakers.json
internal/*
@@ -9,4 +9,19 @@ repos:
    rev: 20.8b1
    hooks:
      - id: black
        language_version: python3
        language_version: python3
  - repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
      - id: isort
        name: isort (python)
      - id: isort
        name: isort (cython)
        types: [cython]
      - id: isort
        name: isort (pyi)
        types: [pyi]
  - repo: https://github.com/pycqa/pylint
    rev: v2.8.2
    hooks:
      - id: pylint
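To enable these hooks locally, the standard pre-commit workflow applies (shown for convenience; not part of this diff):

```bash
pip install pre-commit       # one-time setup
pre-commit install           # register the git hook in your clone
pre-commit run --all-files   # run black, isort and pylint once over the tree
```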
@@ -61,6 +61,9 @@ confidence=
# no Warning level messages displayed, use "--disable=all --enable=classes
# --disable=W".
disable=missing-docstring,
        too-many-public-methods,
        too-many-lines,
        bare-except,
        line-too-long,
        fixme,
        wrong-import-order,
@@ -0,0 +1,18 @@
# .readthedocs.yml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

# Required
version: 2

# Build documentation in the docs/ directory with Sphinx
sphinx:
  builder: html
  configuration: docs/source/conf.py

# Optionally set the version of Python and requirements required to build your docs
python:
  version: 3.7
  install:
    - requirements: docs/requirements.txt
    - requirements: requirements.txt
@@ -2,7 +2,7 @@

Welcome to the 🐸TTS!

This repository is governed by the Contributor Covenant Code of Conduct - [CODE_OF_CONDUCT.md](CODE_OF_CONDUCT.md).
This repository is governed by [the Contributor Covenant Code of Conduct](https://github.com/coqui-ai/TTS/blob/main/CODE_OF_CONDUCT.md).

## Where to start.
We welcome everyone who likes to contribute to 🐸TTS.
@@ -1,6 +1,7 @@
include README.md
include LICENSE.txt
include requirements.*.txt
include requirements.txt
include TTS/VERSION
recursive-include TTS *.json
recursive-include TTS *.html
Makefile
@@ -6,15 +6,9 @@ help:

target_dirs := tests TTS notebooks

system-deps:	## install linux system deps
	sudo apt-get install -y libsndfile1-dev

dev-deps:	## install development deps
	pip install -r requirements.dev.txt
	pip install -r requirements.tf.txt

deps:	## install 🐸 requirements.
	pip install -r requirements.txt
test_all:	## run tests and don't stop on an error.
	nosetests --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id
	./run_bash_tests.sh

test:	## run tests.
	nosetests -x --with-cov -cov --cover-erase --cover-package TTS tests --nologcapture --with-id
@@ -30,5 +24,24 @@ style: ## update code style.
lint:	## run pylint linter.
	pylint ${target_dirs}

system-deps:	## install linux system deps
	sudo apt-get install -y libsndfile1-dev

dev-deps:	## install development deps
	pip install -r requirements.dev.txt
	pip install -r requirements.tf.txt

doc-deps:	## install docs dependencies
	pip install -r docs/requirements.txt

build-docs:	## build the docs
	cd docs && make clean && make build

hub-deps:	## install deps for torch hub use
	pip install -r requirements.hub.txt

deps:	## install 🐸 requirements.
	pip install -r requirements.txt

install:	## install 🐸 TTS for development.
	pip install -e .[all]
README.md
@@ -3,8 +3,9 @@
🐸TTS is a library for advanced Text-to-Speech generation. It's built on the latest research and was designed to achieve the best trade-off among ease of training, speed and quality.
🐸TTS comes with [pretrained models](https://github.com/coqui-ai/TTS/wiki/Released-Models) and tools for measuring dataset quality, and is already used in **20+ languages** for products and research projects.

[]()
[](https://github.com/coqui-ai/TTS/actions)
[](https://opensource.org/licenses/MPL-2.0)
[](https://tts.readthedocs.io/en/latest/)
[](https://badge.fury.io/py/TTS)
[](https://github.com/coqui-ai/TTS/blob/master/CODE_OF_CONDUCT.md)
[](https://pepy.tech/project/tts)

@@ -16,20 +17,17 @@
📢 [English Voice Samples](https://erogol.github.io/ddc-samples/) and [SoundCloud playlist](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)

👩🏽🍳 [TTS training recipes](https://github.com/erogol/TTS_recipes)

📄 [Text-to-Speech paper collection](https://github.com/erogol/TTS-papers)

## 💬 Where to ask questions
Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly, so that more people can benefit from it.
Please use our dedicated channels for questions and discussion. Help is much more valuable if it's shared publicly so that more people can benefit from it.

| Type                            | Platforms                                |
| ------------------------------- | --------------------------------------- |
| 🚨 **Bug Reports**              | [GitHub Issue Tracker]                   |
| ❔ **FAQ**                      | [TTS/Wiki](https://github.com/coqui-ai/TTS/wiki/FAQ) |
| 🎁 **Feature Requests & Ideas** | [GitHub Issue Tracker]                   |
| 👩💻 **Usage Questions**        | [Github Discussions]                     |
| 🗯 **General Discussion**       | [Github Discussions] or [Gitter Room]|
| 🗯 **General Discussion**       | [Github Discussions] or [Gitter Room]    |

[github issue tracker]: https://github.com/coqui-ai/tts/issues
[github discussions]: https://github.com/coqui-ai/TTS/discussions

@@ -40,14 +38,11 @@ Please use our dedicated channels for questions and discussion. Help is much mor
## 🔗 Links and Resources
| Type                            | Links                                    |
| ------------------------------- | --------------------------------------- |
| 💼 **Documentation**            | [ReadTheDocs](https://tts.readthedocs.io/en/latest/)
| 💾 **Installation**             | [TTS/README.md](https://github.com/coqui-ai/TTS/tree/dev#install-tts)|
| 👩💻 **Contributing**           | [CONTRIBUTING.md](https://github.com/coqui-ai/TTS/blob/main/CONTRIBUTING.md)|
| 📌 **Road Map**                 | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378)
| 👩🏾🏫 **Tutorials and Examples** | [TTS/Wiki](https://github.com/coqui-ai/TTS/wiki/%F0%9F%90%B8-TTS-Notebooks,-Examples-and-Tutorials) |
| 🚀 **Released Models**          | [TTS Releases](https://github.com/coqui-ai/TTS/releases) and [Experimental Models](https://github.com/coqui-ai/TTS/wiki/Experimental-Released-Models)|
| 🖥️ **Demo Server**              | [TTS/server](https://github.com/coqui-ai/TTS/tree/master/TTS/server)|
| 🤖 **Synthesize speech**        | [TTS/README.md](https://github.com/coqui-ai/TTS#example-synthesizing-speech-on-terminal-using-the-released-models)|
| 🛠️ **Implementing a New Model** | [TTS/Wiki](https://github.com/coqui-ai/TTS/wiki/Implementing-a-New-Model-in-%F0%9F%90%B8TTS)|

## 🥇 TTS Performance
<p align="center"><img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/TTS-performance.png" width="800" /></p>

@@ -56,20 +51,19 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
<!-- [Details...](https://github.com/coqui-ai/TTS/wiki/Mean-Opinion-Score-Results) -->

## Features
- High performance Deep Learning models for Text2Speech tasks.
- High-performance Deep Learning models for Text2Speech tasks.
- Text2Spec models (Tacotron, Tacotron2, Glow-TTS, SpeedySpeech).
- Speaker Encoder to compute speaker embeddings efficiently.
- Vocoder models (MelGAN, Multiband-MelGAN, GAN-TTS, ParallelWaveGAN, WaveGrad, WaveRNN)
- Fast and efficient model training.
- Detailed training logs on console and Tensorboard.
- Support for multi-speaker TTS.
- Efficient Multi-GPUs training.
- Detailed training logs on the terminal and Tensorboard.
- Support for Multi-speaker TTS.
- Efficient, flexible, lightweight but feature complete `Trainer API`.
- Ability to convert PyTorch models to Tensorflow 2.0 and TFLite for inference.
- Released models in PyTorch, Tensorflow and TFLite.
- Released and ready-to-use models.
- Tools to curate Text2Speech datasets under ```dataset_analysis```.
- Demo server for model testing.
- Notebooks for extensive model benchmarking.
- Modular (but not too much) code base enabling easy testing for new ideas.
- Utilities to use and test your models.
- Modular (but not too much) code base enabling easy implementation of new ideas.

## Implemented Models
### Text-to-Spectrogram

@@ -98,8 +92,9 @@ Underlined "TTS*" and "Judy*" are 🐸TTS models
- WaveRNN: [origin](https://github.com/fatchord/WaveRNN/)
- WaveGrad: [paper](https://arxiv.org/abs/2009.00713)
- HiFiGAN: [paper](https://arxiv.org/abs/2010.05646)
- UnivNet: [paper](https://arxiv.org/abs/2106.07889)

You can also help us implement more models. Some 🐸TTS related work can be found [here](https://github.com/erogol/TTS-papers).
You can also help us implement more models.

## Install TTS
🐸TTS is tested on Ubuntu 18.04 with **python >= 3.6, < 3.9**.

@@ -110,7 +105,7 @@ If you are only interested in [synthesizing speech](https://github.com/coqui-ai/
pip install TTS
```

By default this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.
By default, this only installs the requirements for PyTorch. To install the tensorflow dependencies as well, use the `tf` extra.

```bash
pip install TTS[tf]

@@ -123,12 +118,6 @@ git clone https://github.com/coqui-ai/TTS
pip install -e .[all,dev,notebooks,tf]  # Select the relevant extras
```

We use ```espeak-ng``` to convert graphemes to phonemes. You might need to install it separately.

```bash
sudo apt-get install espeak-ng
```

If you are on Ubuntu (Debian), you can also run the following commands for installation.

```bash

@@ -137,6 +126,7 @@ $ make install
```

If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](https://stackoverflow.com/questions/66726331/how-can-i-run-mozilla-tts-coqui-tts-training-with-cuda-on-a-windows-system).

## Directory Structure
```
|- notebooks/       (Jupyter Notebooks for model evaluation, parameter selection and data analysis.)

@@ -147,6 +137,7 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
    |- distribute.py         (train your TTS model using Multiple GPUs.)
    |- compute_statistics.py (compute dataset statistics for normalization.)
    |- convert*.py           (convert target torch model to TF.)
    |- ...
|- tts/             (text to speech models)
    |- layers/          (model layer definitions)
    |- models/          (model definitions)

@@ -156,167 +147,4 @@ If you are on Windows, 👑@GuyPaddock wrote installation instructions [here](ht
    |- (same)
|- vocoder/         (Vocoder models.)
    |- (same)
```

## Sample Model Output
Below you can see the Tacotron model state after 16K iterations with batch size 32 on the LJSpeech dataset.

> "Recent research at Harvard has shown meditating for as little as 8 weeks can actually increase the grey matter in the parts of the brain responsible for emotional regulation and learning."

Audio examples: [soundcloud](https://soundcloud.com/user-565970875/pocket-article-wavernn-and-tacotron2)

<img src="images/example_model_output.png?raw=true" alt="example_output" width="400"/>

## Datasets and Data-Loading
🐸TTS provides a generic dataloader that is easy to use for your custom dataset.
You just need to write a simple function to format the dataset, as in the sketch below. Check ```datasets/preprocess.py``` to see some examples.
After that, you need to set ```dataset``` fields in ```config.json```.
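A minimal formatter sketch (an assumption-level example, not the library's API surface: the function name ```my_dataset```, the ```my_speaker``` label and the LJSpeech-style ```file_id|transcription``` layout are placeholders; the returned ```[text, wav_path, speaker_name]``` items follow the formatter examples in ```datasets/preprocess.py```):

```python
import os


def my_dataset(root_path, meta_file):
    """Parse `file_id|transcription` rows into TTS dataset items."""
    items = []
    speaker_name = "my_speaker"  # placeholder for a single-speaker dataset
    with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
        for line in f:
            cols = line.strip().split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items
```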
Some of the public datasets to which we have successfully applied 🐸TTS:

- [LJ Speech](https://keithito.com/LJ-Speech-Dataset/)
- [Nancy](http://www.cstr.ed.ac.uk/projects/blizzard/2011/lessac_blizzard2011/)
- [TWEB](https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset)
- [M-AI-Labs](http://www.caito.de/2019/01/the-m-ailabs-speech-dataset/)
- [LibriTTS](https://openslr.org/60/)
- [Spanish](https://drive.google.com/file/d/1Sm_zyBo67XHkiFhcRSQ4YaHPYM0slO_e/view?usp=sharing) - thx! @carlfm01

## Example: Synthesizing Speech on Terminal Using the Released Models.
<img src="images/tts_cli.gif"/>

After the installation, 🐸TTS provides a CLI interface for synthesizing speech using pre-trained models. You can either use your own model or the released models under 🐸TTS.

Listing released 🐸TTS models.

```bash
tts --list_models
```

Run a TTS model, from the released models list, with its default vocoder. (Simply copy and paste the full model names from the list as arguments for the command below.)

```bash
tts --text "Text for TTS" \
    --model_name "<type>/<language>/<dataset>/<model_name>" \
    --out_path folder/to/save/output.wav
```

Run a TTS and a vocoder model from the released model list. Note that not every vocoder is compatible with every TTS model.

```bash
tts --text "Text for TTS" \
    --model_name "<type>/<language>/<dataset>/<model_name>" \
    --vocoder_name "<type>/<language>/<dataset>/<model_name>" \
    --out_path folder/to/save/output.wav
```

Run your own TTS model (using the Griffin-Lim vocoder).

```bash
tts --text "Text for TTS" \
    --model_path path/to/model.pth.tar \
    --config_path path/to/config.json \
    --out_path folder/to/save/output.wav
```

Run your own TTS and vocoder models.

```bash
tts --text "Text for TTS" \
    --config_path path/to/config.json \
    --model_path path/to/model.pth.tar \
    --out_path folder/to/save/output.wav \
    --vocoder_path path/to/vocoder.pth.tar \
    --vocoder_config_path path/to/vocoder_config.json
```

Run a multi-speaker TTS model from the released models list.

```bash
tts --model_name "<type>/<language>/<dataset>/<model_name>" --list_speaker_idxs  # list the possible speaker IDs.
tts --text "Text for TTS." --out_path output/path/speech.wav --model_name "<language>/<dataset>/<model_name>" --speaker_idx "<speaker_id>"
```

**Note:** You can use ```./TTS/bin/synthesize.py``` if you prefer running ```tts``` from the TTS project folder.
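For example, a sketch assuming ```synthesize.py``` accepts the same flags as the ```tts``` entry point it backs:

```bash
python TTS/bin/synthesize.py --text "Text for TTS" \
    --model_name "<type>/<language>/<dataset>/<model_name>" \
    --out_path folder/to/save/output.wav
```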
## Example: Using the Demo Server for Synthesizing Speech

<!-- <img src="https://raw.githubusercontent.com/coqui-ai/TTS/main/images/demo_server.gif" height="56"/> -->
<img src="images/demo_server.gif"/>

You can boot up a demo 🐸TTS server to run inference with your models. Note that the server is not optimized for performance
but gives you an easy way to interact with the models.

The demo server provides pretty much the same interface as the CLI command.

```bash
tts-server -h # see the help
tts-server --list_models # list the available models.
```

Run a TTS model, from the released models list, with its default vocoder.
If the model you choose is a multi-speaker TTS model, you can select different speakers on the Web interface and synthesize
speech.

```bash
tts-server --model_name "<type>/<language>/<dataset>/<model_name>"
```

Run a TTS and a vocoder model from the released model list. Note that not every vocoder is compatible with every TTS model.

```bash
tts-server --model_name "<type>/<language>/<dataset>/<model_name>" \
    --vocoder_name "<type>/<language>/<dataset>/<model_name>"
```


## Example: Training and Fine-tuning LJ-Speech Dataset
Here you can find a [CoLab](https://gist.github.com/erogol/97516ad65b44dbddb8cd694953187c5b) notebook for a hands-on example of training LJSpeech. Or you can manually follow the guide below.

To start with, split ```metadata.csv``` into train and validation subsets respectively ```metadata_train.csv``` and ```metadata_val.csv```. Note that for text-to-speech, validation performance might be misleading, since the loss value does not directly measure the voice quality to the human ear and it also does not measure the attention module performance. Therefore, running the model with new sentences and listening to the results is the best way to go.

```
shuf metadata.csv > metadata_shuf.csv
head -n 12000 metadata_shuf.csv > metadata_train.csv
tail -n 1100 metadata_shuf.csv > metadata_val.csv
```

To train a new model, you need to define your own ```config.json``` that sets the model details, the training configuration and more (check the examples). Then call the corresponding train script.

For instance, in order to train a tacotron or tacotron2 model on the LJSpeech dataset, follow these steps.

```bash
python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json
```

To fine-tune a model, use ```--restore_path```.

```bash
python TTS/bin/train_tacotron.py --config_path TTS/tts/configs/config.json --restore_path /path/to/your/model.pth.tar
```

To continue an old training run, use ```--continue_path```.

```bash
python TTS/bin/train_tacotron.py --continue_path /path/to/your/run_folder/
```

For multi-GPU training, call ```distribute.py```. It runs any provided train script in a multi-GPU setting.

```bash
CUDA_VISIBLE_DEVICES="0,1,4" python TTS/bin/distribute.py --script train_tacotron.py --config_path TTS/tts/configs/config.json
```

Each run creates a new output folder holding the used ```config.json```, model checkpoints and Tensorboard logs.

In case of any error or intercepted execution, if there is no checkpoint yet under the output folder, the whole folder is going to be removed.

You can also enjoy Tensorboard if you point the Tensorboard argument ```--logdir``` to the experiment folder.
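For example (the run folder path is a placeholder):

```bash
tensorboard --logdir /path/to/your/run_folder/
```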
## [Contribution guidelines](https://github.com/coqui-ai/TTS/blob/main/CONTRIBUTING.md)
### Acknowledgement
- https://github.com/keithito/tacotron (Dataset pre-processing)
- https://github.com/r9y9/tacotron_pytorch (Initial Tacotron architecture)
- https://github.com/kan-bayashi/ParallelWaveGAN (GAN based vocoder library)
- https://github.com/jaywalnut310/glow-tts (Original Glow-TTS implementation)
- https://github.com/fatchord/WaveRNN/ (Original WaveRNN implementation)
- https://arxiv.org/abs/2010.05646 (Original HiFiGAN implementation)
@@ -4,10 +4,9 @@
        "ek1":{
            "tacotron2": {
                "description": "EK1 en-rp tacotron2 by NMStoker",
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--en--ek1--tacotron2.zip",
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ek1--tacotron2.zip",
                "default_vocoder": "vocoder_models/en/ek1/wavegrad",
                "commit": "c802255",
                "needs_phonemizer": true
                "commit": "c802255"
            }
        },
        "ljspeech":{
@@ -18,8 +17,7 @@
                "commit": "bae2ad0f",
                "author": "Eren Gölge @erogol",
                "license": "",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": false
                "contact":"egolge@coqui.com"
            },
            "glow-tts":{
                "description": "",
@@ -29,8 +27,7 @@
                "commit": "",
                "author": "Eren Gölge @erogol",
                "license": "MPL",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            },
            "tacotron2-DCA": {
                "description": "",
@@ -39,30 +36,27 @@
                "commit": "",
                "author": "Eren Gölge @erogol",
                "license": "MPL",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            },
            "speedy-speech-wn":{
                "description": "Speedy Speech model with wavenet decoder.",
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.10/tts_models--en--ljspeech--speedy-speech-wn.zip",
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--ljspeech--speedy-speech-wn.zip",
                "default_vocoder": "vocoder_models/en/ljspeech/multiband-melgan",
                "commit": "77b6145",
                "author": "Eren Gölge @erogol",
                "license": "MPL",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            }
        },
        "vctk":{
            "sc-glow-tts": {
                "description": "Multi-Speaker Transformers based SC-Glow model from https://arxiv.org/abs/2104.05557.",
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.12/tts_models--en--vctk--sc-glowtts-transformer.zip",
                "default_vocoder": null,
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.1.0/tts_models--en--vctk--sc-glow-tts.zip",
                "default_vocoder": "vocoder_models/en/vctk/hifigan_v2",
                "commit": "b531fa69",
                "author": "Edresson Casanova",
                "license": "",
                "contact":"",
                "needs_phonemizer": true
                "contact":""

            }
@@ -75,8 +69,7 @@
                "commit": "bae2ad0f",
                "author": "Eren Gölge @erogol",
                "license": "",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            }
        }
    },
@@ -88,8 +81,7 @@
                "commit": "",
                "author": "Eren Gölge @erogol",
                "license": "MPL",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            }
        }
    },
@@ -101,8 +93,7 @@
                "commit": "",
                "author": "Eren Gölge @erogol",
                "license": "MPL",
                "contact":"egolge@coqui.com",
                "needs_phonemizer": true
                "contact":"egolge@coqui.com"
            }
        }
    },
@@ -122,8 +113,7 @@
                "author": "@r-dh",
                "default_vocoder": "vocoder_models/nl/mai/parallel-wavegan",
                "stats_file": null,
                "commit": "540d811",
                "needs_phonemizer": true
                "commit": "540d811"
            }
        }
    },
@@ -134,8 +124,7 @@
                "author": "@erogol",
                "default_vocoder": "vocoder_models/universal/libri-tts/fullband-melgan",
                "license":"",
                "contact": "egolge@coqui.com",
                "needs_phonemizer": true
                "contact": "egolge@coqui.com"
            }
        }
    },
@@ -145,8 +134,7 @@
                "github_rls_url": "https://github.com/coqui-ai/TTS/releases/download/v0.0.11/tts_models--de--thorsten--tacotron2-DCA.zip",
                "default_vocoder": "vocoder_models/de/thorsten/wavegrad",
                "author": "@thorstenMueller",
                "commit": "unknown",
                "needs_phonemizer": true
                "commit": "unknown"
            }
        }
    },
@@ -157,8 +145,7 @@
                "default_vocoder": "vocoder_models/universal/libri-tts/wavegrad",
                "description": "Tacotron2 with Double Decoder Consistency trained with Kokoro Speech Dataset.",
                "author": "@kaiidams",
                "commit": "401fbd89",
                "needs_phonemizer": false
                "commit": "401fbd89"
            }
        }
    }
@@ -1 +1 @@
0.0.16
0.1.2
@@ -8,8 +8,8 @@ import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor

@@ -75,21 +75,21 @@ Example run:
# load the model
num_chars = len(phonemes) if C.use_phonemes else len(symbols)
# TODO: handle multi-speaker
model = setup_model(num_chars, num_speakers=0, c=C)
model = setup_model(C)
model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
model.eval()

# data loader
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = importlib.import_module("TTS.tts.datasets.formatters")
preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset(
dataset = TTSDataset(
    model.decoder.r,
    C.text_cleaner,
    compute_linear_spec=False,
    ap=ap,
    meta_data=meta_data,
    tp=C.characters if "characters" in C.keys() else None,
    characters=C.characters if "characters" in C.keys() else None,
    add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
    use_phonemes=C.use_phonemes,
    phoneme_cache_path=C.phoneme_cache_path,
@@ -1,12 +1,17 @@
import argparse
import glob
import os

import torch
from tqdm import tqdm
from TTS.config import load_config

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.config import BaseDatasetConfig, load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.tts.datasets import load_meta_data
from TTS.tts.utils.speakers import SpeakerManager


parser = argparse.ArgumentParser(
    description='Compute embedding vectors for each wav file in a dataset.'
)
@@ -24,11 +29,13 @@ parser.add_argument(
)
parser.add_argument("output_path", type=str, help="path for output speakers.json and/or speakers.npy.")
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=True)
parser.add_argument("--eval", type=bool, help="compute eval.", default=True)

args = parser.parse_args()

c_dataset = load_config(args.config_dataset_path)

train_files, dev_files = load_meta_data(c_dataset.datasets, eval_split=True, ignore_generated_eval=True)
train_files, dev_files = load_meta_data(c_dataset.datasets, eval_split=args.eval, ignore_generated_eval=True)
wav_files = train_files + dev_files

speaker_manager = SpeakerManager(encoder_model_path=args.model_path, encoder_config_path=args.config_path, use_cuda=args.use_cuda)
@@ -43,7 +50,7 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
        speaker_name = None

    # extract the embedding
    embedd = speaker_manager.compute_x_vector_from_clip(wav_file)
    embedd = speaker_manager.compute_d_vector_from_clip(wav_file)

    # create speaker_mapping if target dataset is defined
    wav_file_name = os.path.basename(wav_file)
@@ -10,7 +10,7 @@ from tqdm import tqdm

# from TTS.utils.io import load_config
from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor

@@ -77,7 +77,7 @@ def main():
    print(f" > Avg mel spec mean: {mel_mean.mean()}")
    print(f" > Avg mel spec scale: {mel_scale.mean()}")
    print(f" > Avg linear spec mean: {linear_mean.mean()}")
    print(f" > Avg lienar spec scale: {linear_scale.mean()}")
    print(f" > Avg linear spec scale: {linear_scale.mean()}")

    # set default config values for mean-var scaling
    CONFIG.audio.stats_path = output_file_path
@@ -8,10 +8,10 @@ import numpy as np
import tensorflow as tf
import torch

from TTS.tts.models import setup_model
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config

@@ -31,18 +31,18 @@ c = load_config(config_path)
num_speakers = 0

# init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
model = setup_model(c)
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)

# init tf model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_tf = Tacotron2(
    num_chars=num_chars,
    num_speakers=num_speakers,
    r=model.decoder.r,
    postnet_output_dim=c.audio["num_mels"],
    out_channels=c.audio["num_mels"],
    decoder_output_dim=c.audio["num_mels"],
    attn_type=c.attention_type,
    attn_win=c.windowing,
@@ -1,47 +1,38 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import os
import pathlib
import subprocess
import sys
import time

import torch

from TTS.trainer import TrainingArgs


def main():
    """
    Call train.py as a new process and pass command arguments
    """
    parser = argparse.ArgumentParser()
    parser = TrainingArgs().init_argparse(arg_prefix="")
    parser.add_argument("--script", type=str, help="Target training script to distribute.")
    parser.add_argument(
        "--continue_path",
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default="",
        required="--config_path" not in sys.argv,
    )
    parser.add_argument(
        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
    )
    parser.add_argument(
        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
    )
    args = parser.parse_args()
    args, unargs = parser.parse_known_args()

    num_gpus = torch.cuda.device_count()
    group_id = time.strftime("%Y_%m_%d-%H%M%S")

    # set arguments for train.py
    folder_path = pathlib.Path(__file__).parent.absolute()
    command = [os.path.join(folder_path, args.script)]
    if os.path.exists(os.path.join(folder_path, args.script)):
        command = [os.path.join(folder_path, args.script)]
    else:
        command = [args.script]
    command.append("--continue_path={}".format(args.continue_path))
    command.append("--restore_path={}".format(args.restore_path))
    command.append("--config_path={}".format(args.config_path))
    command.append("--group_id=group_{}".format(group_id))
    command += unargs
    command.append("")

    # run processes
@@ -50,6 +41,7 @@ def main():
        my_env = os.environ.copy()
        my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
        command[-1] = "--rank={}".format(i)
        # prevent stdout for processes with rank != 0
        stdout = None if i == 0 else open(os.devnull, "w")
        p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)  # pylint: disable=consider-using-with
        processes.append(p)
@@ -10,11 +10,10 @@ from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.config import load_config
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.datasets import load_meta_data
from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.tts.models import setup_model
from TTS.tts.utils.speakers import get_speaker_manager
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters

@@ -22,13 +21,13 @@ use_cuda = torch.cuda.is_available()


def setup_loader(ap, r, verbose=False):
    dataset = MyDataset(
    dataset = TTSDataset(
        r,
        c.text_cleaner,
        compute_linear_spec=False,
        meta_data=meta_data,
        ap=ap,
        tp=c.characters if "characters" in c.keys() else None,
        characters=c.characters if "characters" in c.keys() else None,
        add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
        batch_group_size=0,
        min_seq_len=c.min_seq_len,
@@ -39,7 +38,8 @@ def setup_loader(ap, r, verbose=False):
        enable_eos_bos=c.enable_eos_bos_chars,
        use_noise_augment=False,
        verbose=verbose,
        speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None,
        speaker_id_mapping=speaker_manager.speaker_ids,
        d_vector_mapping=speaker_manager.d_vectors if c.use_speaker_embedding and c.use_d_vector_file else None,
    )

    if c.use_phonemes and c.compute_input_seq_cache:
@@ -78,26 +78,15 @@ def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4]
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    d_vectors = data[8]
    speaker_ids = data[9]
    attn_mask = data[10]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        if c.use_external_speaker_embedding_file:
            speaker_embeddings = data[8]
            speaker_ids = None
        else:
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
            speaker_embeddings = None
    else:
        speaker_embeddings = None
        speaker_ids = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
@@ -106,9 +95,8 @@ def format_data(data):
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_ids is not None:
            speaker_ids = speaker_ids.cuda(non_blocking=True)
        if speaker_embeddings is not None:
            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)

        if d_vectors is not None:
            d_vectors = d_vectors.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return (
@@ -117,7 +105,7 @@ def format_data(data):
        mel_input,
        mel_lengths,
        speaker_ids,
        speaker_embeddings,
        d_vectors,
        avg_text_length,
        avg_spec_length,
        attn_mask,
@@ -134,32 +122,26 @@ def inference(
    text_lengths,
    mel_input,
    mel_lengths,
    attn_mask=None,
    speaker_ids=None,
    speaker_embeddings=None,
    d_vectors=None,
):
    if model_name == "glow_tts":
        mel_input = mel_input.permute(0, 2, 1)  # B x D x T
        speaker_c = None
        if speaker_ids is not None:
            speaker_c = speaker_ids
        elif speaker_embeddings is not None:
            speaker_c = speaker_embeddings
        elif d_vectors is not None:
            speaker_c = d_vectors

        model_output, *_ = model.inference_with_MAS(
            text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
        outputs = model.inference_with_MAS(
            text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": speaker_c}
        )
        model_output = outputs["model_outputs"]
        model_output = model_output.transpose(1, 2).detach().cpu().numpy()

    elif "tacotron" in model_name:
        _, postnet_outputs, *_ = model(
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_ids=speaker_ids,
            speaker_embeddings=speaker_embeddings,
        )
        aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
        outputs = model(text_input, text_lengths, mel_input, mel_lengths, aux_input)
        postnet_outputs = outputs["model_outputs"]
        # normalize tacotron output
        if model_name == "tacotron":
            mel_specs = []
@@ -188,10 +170,10 @@ def extract_spectrograms(
            mel_input,
            mel_lengths,
            speaker_ids,
            speaker_embeddings,
            d_vectors,
            _,
            _,
            _,
            attn_mask,
            item_idx,
        ) = format_data(data)

@@ -203,9 +185,8 @@ def extract_spectrograms(
            text_lengths,
            mel_input,
            mel_lengths,
            attn_mask,
            speaker_ids,
            speaker_embeddings,
            d_vectors,
        )

        for idx in range(text_input.shape[0]):
@@ -240,28 +221,22 @@ def extract_spectrograms(

def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data, symbols, phonemes, model_characters, speaker_mapping
    global meta_data, speaker_manager

    # Audio processor
    ap = AudioProcessor(**c.audio)
    if "characters" in c.keys() and c["characters"]:
        symbols, phonemes = make_symbols(**c.characters)

    # set model characters
    model_characters = phonemes if c.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True, ignore_generated_eval=True)

    # use eval and training partitions
    meta_data = meta_data_train + meta_data_eval

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, None)
    speaker_manager = get_speaker_manager(c, args, meta_data_train)

    # setup model
    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
    model = setup_model(c)

    # restore model
    checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
@@ -299,6 +274,5 @@ if __name__ == "__main__":
    args = parser.parse_args()

    c = load_config(args.config_path)
    c.audio["do_trim_silence"] = False  # IMPORTANT!!!!!!!!!!!!!!! disable to align mel

    c.audio.trim_silence = False
    main(args)
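A hedged invocation sketch (```--config_path``` and ```--checkpoint_path``` are assumed to be CLI flags matching the ```args``` attributes used above; this diff does not show the full parser):

```bash
python TTS/bin/extract_tts_spectrograms.py --config_path config.json --checkpoint_path checkpoint.pth.tar
```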
@@ -1,7 +1,7 @@
"""Find all the unique characters in a dataset"""
import argparse
from argparse import RawTextHelpFormatter

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets import load_meta_data
from TTS.config import load_config


@@ -9,7 +9,6 @@ def main():
    # pylint: disable=bad-option-value
    parser = argparse.ArgumentParser(
        description="""Find all the unique characters or phonemes in a dataset.\n\n"""
        """\n\n"""
        """
    Example runs:

@@ -23,9 +22,10 @@ def main():
    args = parser.parse_args()

    c = load_config(args.config_path)

    # load all datasets
    train_items, dev_items = load_meta_data(c.datasets, eval_split=True, ignore_generated_eval=True)
    items = train_items + dev_items
    train_items, eval_items = load_meta_data(c.datasets, eval_split=True, ignore_generated_eval=True)
    items = train_items + eval_items

    texts = "".join(item[0] for item in items)
    chars = set(texts)
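A hedged usage sketch (```--config_path``` is the flag read via ```load_config``` above; the path is a placeholder):

```bash
python TTS/bin/find_unique_chars.py --config_path /path/to/config.json
```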
@@ -157,7 +157,7 @@ def main():
    parser.add_argument(
        "--speaker_wav",
        nargs="+",
        help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The x_vectors is computed as their average.",
        help="wav file(s) to condition a multi-speaker TTS model with a Speaker Encoder. You can give multiple file paths. The d_vectors is computed as their average.",
        default=None,
    )
    parser.add_argument("--gst_style", help="Wav path file for GST style reference.", default=None)
@@ -239,7 +239,7 @@ def main():
        print(
            " > Available speaker ids: (Set --speaker_idx flag to one of these values to use the multi-speaker model."
        )
        print(synthesizer.speaker_manager.speaker_ids)
        print(synthesizer.tts_model.speaker_manager.speaker_ids)
        return

    # check the arguments against a multi-speaker model.
@@ -1,572 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import time
import traceback
from random import randrange

import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import AlignTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env

use_cuda, num_gpus = setup_torch_training_env(True, False)
# torch.autograd.set_detect_anomaly(True)


def setup_loader(ap, r, is_val=False, verbose=False):
    if is_val and not config.run_eval:
        loader = None
    else:
        dataset = MyDataset(
            r,
            config.text_cleaner,
            compute_linear_spec=False,
            meta_data=meta_data_eval if is_val else meta_data_train,
            ap=ap,
            tp=config.characters,
            add_blank=config["add_blank"],
            batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
            min_seq_len=config.min_seq_len,
            max_seq_len=config.max_seq_len,
            phoneme_cache_path=config.phoneme_cache_path,
            use_phonemes=config.use_phonemes,
            phoneme_language=config.phoneme_language,
            enable_eos_bos=config.enable_eos_bos_chars,
            use_noise_augment=not is_val,
            verbose=verbose,
            speaker_mapping=speaker_mapping
            if config.use_speaker_embedding and config.use_external_speaker_embedding_file
            else None,
        )

        if config.use_phonemes and config.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
            dataset.compute_input_seq(config.num_loader_workers)
        dataset.sort_items()

        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_val else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    item_idx = data[7]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if config.use_speaker_embedding:
        if config.use_external_speaker_embedding_file:
            # return precomputed embedding vector
            speaker_c = data[8]
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_c = torch.LongTensor(speaker_c)
    else:
        speaker_c = None
    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_c is not None:
            speaker_c = speaker_c.cuda(non_blocking=True)
    return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx


def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase):

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_targets,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
                text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
            )

            # compute loss
            loss_dict = criterion(
                logp,
                decoder_output,
                mel_targets,
                mel_lengths,
                dur_output,
                dur_mas_output,
                text_lengths,
                global_step,
                phase=training_phase,
            )

        # backward pass with loss scaling
        if config.mixed_precision:
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
            loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
            loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait all kernels to be completed
                torch.cuda.synchronize()

                # Diagnostic visualizations
                if decoder_output is not None:
                    idx = np.random.randint(mel_targets.shape[0])
                    pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                    gt_spec = mel_targets[idx].data.cpu().numpy().T
                    align_img = alignments[idx].data.cpu()

                    figures = {
                        "prediction": plot_spectrogram(pred_spec, ap),
                        "ground_truth": plot_spectrogram(gt_spec, ap),
                        "alignment": plot_alignment(align_img),
                    }

                    tb_logger.tb_train_figures(global_step, figures)

                    # Sample audio
                    train_audio = ap.inv_melspectrogram(pred_spec.T)
                    tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase):
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=config.mixed_precision):
                decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
                    text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
                )

                # compute loss
                loss_dict = criterion(
                    logp,
                    decoder_output,
                    mel_targets,
                    mel_lengths,
                    dur_output,
                    dur_mas_output,
                    text_lengths,
                    global_step,
                    phase=training_phase,
                )

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
                loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
                loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

            if config.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= config.test_delay_epochs:
        if config.test_sentences_file:
            with open(config.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]
        else:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if config.use_speaker_embedding:
            if config.use_external_speaker_embedding_file:
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    config,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=None,
                    truncated=False,
                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
    # Audio processor
    ap = AudioProcessor(**config.audio.to_dict())
    if config.has("characters") and config.characters:
        symbols, phonemes = make_symbols(**config.characters.to_dict())

    # DISTRUBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])

    # set model characters
    model_characters = phonemes if config.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(config.datasets, eval_split=True)

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)

    # setup model
    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim)
    optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
    criterion = AlignTTSLoss(config)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
|
||||
checkpoint = torch.load(args.restore_path, map_location="cpu")
|
||||
try:
|
||||
# TODO: fix optimizer init, model.cuda() needs to be called before
|
||||
# optimizer restore
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
if config.reinit_layers:
|
||||
raise RuntimeError
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
except: # pylint: disable=bare-except
|
||||
print(" > Partial model initialization.")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], config)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = config.lr
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
criterion.cuda()
|
||||
|
||||
# DISTRUBUTED
|
||||
if num_gpus > 1:
|
||||
model = DDP_th(model, device_ids=[args.rank])
|
||||
|
||||
if config.noam_schedule:
|
||||
scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
|
||||
else:
|
||||
scheduler = None
|
||||
|
||||
num_params = count_parameters(model)
|
||||
print("\n > Model has {} parameters".format(num_params), flush=True)
|
||||
|
||||
if args.restore_step == 0 or not args.best_path:
|
||||
best_loss = float("inf")
|
||||
print(" > Starting with inf best loss.")
|
||||
else:
|
||||
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with loaded last best loss {best_loss}.")
|
||||
keep_all_best = config.keep_all_best
|
||||
keep_after = config.keep_after # void if keep_all_best False
|
||||
|
||||
# define dataloaders
|
||||
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
|
||||
eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)
|
||||
|
||||
global_step = args.restore_step
|
||||
|
||||
def set_phase():
|
||||
"""Set AlignTTS training phase"""
|
||||
if isinstance(config.phase_start_steps, list):
|
||||
vals = [i < global_step for i in config.phase_start_steps]
|
||||
if not True in vals:
|
||||
phase = 0
|
||||
else:
|
||||
phase = (
|
||||
len(config.phase_start_steps)
|
||||
- [i < global_step for i in config.phase_start_steps][::-1].index(True)
|
||||
- 1
|
||||
)
|
||||
else:
|
||||
phase = None
|
||||
return phase
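    # A sketch of what set_phase() returns, assuming for illustration
    # phase_start_steps=[0, 10000, 40000]: at global_step=25000 the comprehension
    # yields [True, True, False], so the right-most True selects phase 1; with no
    # True entries it falls back to phase 0, and a non-list value disables phased
    # training (phase=None).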
    for epoch in range(0, config.epochs):
        cur_phase = set_phase()
        print(f"\n > Current AlignTTS phase: {cur_phase}")
        c_logger.print_epoch_start(epoch, config.epochs)
        train_avg_loss_dict, global_step = train(
            train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase
        )
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict["avg_loss"]
        if config.run_eval:
            target_loss = eval_avg_loss_dict["avg_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            global_step,
            epoch,
            1,
            OUT_PATH,
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
        )


if __name__ == "__main__":
    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
@@ -13,8 +13,8 @@ from TTS.speaker_encoder.dataset import SpeakerEncoderDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss, SoftmaxAngleProtoLoss
from TTS.speaker_encoder.utils.generic_utils import save_best_model, setup_model
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.arguments import init_training
from TTS.trainer import init_training
from TTS.tts.datasets import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam

@@ -164,7 +164,7 @@ def main(args):  # pylint: disable=redefined-outer-name
    elif c.loss == "angleproto":
        criterion = AngleProtoLoss()
    elif c.loss == "softmaxproto":
        criterion = SoftmaxAngleProtoLoss(c.model["proj_dim"], num_speakers)
        criterion = SoftmaxAngleProtoLoss(c.model_params["proj_dim"], num_speakers)
    else:
        raise Exception("%s is not a supported loss" % c.loss)
@@ -1,598 +0,0 @@
#!/usr/bin/env python3
"""Train Glow TTS model."""

import os
import sys
import time
import traceback
from random import randrange

import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import GlowTTSLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env

use_cuda, num_gpus = setup_torch_training_env(True, False)


def setup_loader(ap, r, is_val=False, verbose=False):
    if is_val and not config.run_eval:
        loader = None
    else:
        dataset = MyDataset(
            r,
            config.text_cleaner,
            compute_linear_spec=False,
            meta_data=meta_data_eval if is_val else meta_data_train,
            ap=ap,
            tp=config.characters,
            add_blank=config["add_blank"],
            batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
            min_seq_len=config.min_seq_len,
            max_seq_len=config.max_seq_len,
            phoneme_cache_path=config.phoneme_cache_path,
            use_phonemes=config.use_phonemes,
            phoneme_language=config.phoneme_language,
            enable_eos_bos=config.enable_eos_bos_chars,
            use_noise_augment=not is_val,
            verbose=verbose,
            speaker_mapping=speaker_mapping
            if config.use_speaker_embedding and config.use_external_speaker_embedding_file
            else None,
        )

        if config.use_phonemes and config.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
            dataset.compute_input_seq(config.num_loader_workers)
        dataset.sort_items()

        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_val else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if config.use_speaker_embedding:
        if config.use_external_speaker_embedding_file:
            # return precomputed embedding vector
            speaker_c = data[8]
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_c = torch.LongTensor(speaker_c)
    else:
        speaker_c = None

    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_c is not None:
            speaker_c = speaker_c.cuda(non_blocking=True)
        if attn_mask is not None:
            attn_mask = attn_mask.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_c,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        item_idx,
    )


def data_depended_init(data_loader, model):
    """Data-dependent initialization for activation normalization."""
    if hasattr(model, "module"):
        for f in model.module.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(True)
    else:
        for f in model.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(True)

    model.train()
    print(" > Data-dependent initialization ... ")
    num_iter = 0
    with torch.no_grad():
        for _, data in enumerate(data_loader):

            # format data
            text_input, text_lengths, mel_input, mel_lengths, speaker_embed, _, _, attn_mask, _ = format_data(data)

            # forward pass model
            _ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_embed)
            if num_iter == config.data_dep_init_steps:
                break
            num_iter += 1

    if hasattr(model, "module"):
        for f in model.module.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(False)
    else:
        for f in model.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(False)
    return model
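# Note: the data-dependent initialization (DDI) pass above runs a handful of
# no-grad forward batches so the ActNorm-style flow layers in the Glow TTS
# decoder can initialize their scale/bias from real activation statistics before
# optimization begins; the flow internals themselves live in TTS.tts.layers.glow_tts
# and are not shown here.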
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            attn_mask,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
            )

        # compute loss
        loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)

        # backward pass with loss scaling
        if config.mixed_precision:
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()
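        # A note on the mixed-precision branch above: GradScaler.scale() multiplies
        # the loss before backward() so fp16 gradients do not underflow, unscale_()
        # restores true gradient magnitudes so clip_grad_norm_ clips real values,
        # and update() adapts the loss scale for the next iteration.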

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower, the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
            loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait for all kernels to be completed
                torch.cuda.synchronize()

                # Diagnostic visualizations
                # direct pass on model for spec predictions
                target_speaker = None if speaker_c is None else speaker_c[:1]

                if hasattr(model, "module"):
                    spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
                else:
                    spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)

                spec_pred = spec_pred.permute(0, 2, 1)
                gt_spec = mel_input.permute(0, 2, 1)
                const_spec = spec_pred[0].data.cpu().numpy()
                gt_spec = gt_spec[0].data.cpu().numpy()
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data)

            # forward pass model
            z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
                text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
            )

            # compute loss
            loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
                loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

            if config.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            # direct pass on model for spec predictions
            target_speaker = None if speaker_c is None else speaker_c[:1]
            if hasattr(model, "module"):
                spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
            else:
                spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
            spec_pred = spec_pred.permute(0, 2, 1)
            gt_spec = mel_input.permute(0, 2, 1)

            const_spec = spec_pred[0].data.cpu().numpy()
            gt_spec = gt_spec[0].data.cpu().numpy()
            align_img = alignments[0].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap),
                "ground_truth": plot_spectrogram(gt_spec, ap),
                "alignment": plot_alignment(align_img),
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= config.test_delay_epochs:
        if config.test_sentences_file:
            with open(config.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]
        else:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if config.use_speaker_embedding:
            if config.use_external_speaker_embedding_file:
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = config.style_wav_for_test
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    config,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values


def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
    # Audio processor
    ap = AudioProcessor(**config.audio.to_dict())
    if config.has("characters") and config.characters:
        symbols, phonemes = make_symbols(**config.characters.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])

    # set model characters
    model_characters = phonemes if config.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(config.datasets)

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)

    # setup model
    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim)
    optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
    criterion = GlowTTSLoss()

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            optimizer.load_state_dict(checkpoint["optimizer"])
            model.load_state_dict(checkpoint["model"])
        except:  # pylint: disable=bare-except
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group["initial_lr"] = config.lr
        print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    if config.noam_schedule:
        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = config.keep_all_best
    keep_after = config.keep_after  # void if keep_all_best False

    # define dataloaders
    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

    global_step = args.restore_step
    model = data_depended_init(train_loader, model)
    for epoch in range(0, config.epochs):
        c_logger.print_epoch_start(epoch, config.epochs)
        train_avg_loss_dict, global_step = train(
            train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
        )
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict["avg_loss"]
        if config.run_eval:
            target_loss = eval_avg_loss_dict["avg_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            global_step,
            epoch,
            config.r,
            OUT_PATH,
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
        )


if __name__ == "__main__":
    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
@@ -1,578 +0,0 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import sys
import time
import traceback
from random import randrange

import numpy as np
import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env

use_cuda, num_gpus = setup_torch_training_env(True, False)


def setup_loader(ap, r, is_val=False, verbose=False):
    if is_val and not config.run_eval:
        loader = None
    else:
        dataset = MyDataset(
            r,
            config.text_cleaner,
            compute_linear_spec=False,
            meta_data=meta_data_eval if is_val else meta_data_train,
            ap=ap,
            tp=config.characters,
            add_blank=config["add_blank"],
            batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
            min_seq_len=config.min_seq_len,
            max_seq_len=config.max_seq_len,
            phoneme_cache_path=config.phoneme_cache_path,
            use_phonemes=config.use_phonemes,
            phoneme_language=config.phoneme_language,
            enable_eos_bos=config.enable_eos_bos_chars,
            use_noise_augment=not is_val,
            verbose=verbose,
            speaker_mapping=speaker_mapping
            if config.use_speaker_embedding and config.use_external_speaker_embedding_file
            else None,
        )

        if config.use_phonemes and config.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
            dataset.compute_input_seq(config.num_loader_workers)
        dataset.sort_items()

        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_val else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if config.use_speaker_embedding:
        if config.use_external_speaker_embedding_file:
            # return precomputed embedding vector
            speaker_c = data[8]
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_c = torch.LongTensor(speaker_c)
    else:
        speaker_c = None
    # compute durations from attention mask
    durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
    for idx, am in enumerate(attn_mask):
        # compute raw durations
        c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
        # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
        c_idxs, counts = torch.unique(c_idxs, return_counts=True)
        dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
        dur[c_idxs] = counts
        # smooth the durations and set any 0 duration to 1
        # by cutting off from the largest duration indices.
        extra_frames = dur.sum() - mel_lengths[idx]
        largest_idxs = torch.argsort(-dur)[:extra_frames]
        dur[largest_idxs] -= 1
        assert (
            dur.sum() == mel_lengths[idx]
        ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
        durations[idx, : text_lengths[idx]] = dur
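    # In short: each spectrogram frame is assigned to its argmax character in the
    # attention map, the per-character counts become durations, and any surplus
    # frames relative to the spectrogram length are shaved off the largest
    # durations so the totals match exactly.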
    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_c is not None:
            speaker_c = speaker_c.cuda(non_blocking=True)
        attn_mask = attn_mask.cuda(non_blocking=True)
        durations = durations.cuda(non_blocking=True)
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_c,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        durations,
        item_idx,
    )


def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_targets,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            _,
            dur_target,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            decoder_output, dur_output, alignments = model.forward(
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
            )

        # compute loss
        loss_dict = criterion(
            decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
        )

        # backward pass with loss scaling
        if config.mixed_precision:
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
            optimizer.step()

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower, the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
            loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
            loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait for all kernels to be completed
                torch.cuda.synchronize()

                # Diagnostic visualizations
                idx = np.random.randint(mel_targets.shape[0])
                pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                gt_spec = mel_targets[idx].data.cpu().numpy().T
                align_img = alignments[idx].data.cpu()

                figures = {
                    "prediction": plot_spectrogram(pred_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(pred_spec.T)
                tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=config.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
                )

            # compute loss
            loss_dict = criterion(
                decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
            )

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
                loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
                loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

            if config.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= config.test_delay_epochs:
        if config.test_sentences_file:
            with open(config.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]
        else:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if config.use_speaker_embedding:
            if config.use_external_speaker_embedding_file:
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    config,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=None,
                    truncated=False,
                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
    # Audio processor
    ap = AudioProcessor(**config.audio.to_dict())
    if config.characters is not None:
        symbols, phonemes = make_symbols(**config.characters.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])

    # set model characters
    model_characters = phonemes if config.use_phonemes else symbols
    num_chars = len(model_characters)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(config.datasets, eval_split=True)

    # set the portion of the data used for training if set in config.json
    if config.has("train_portion"):
        meta_data_train = meta_data_train[: int(len(meta_data_train) * config.train_portion)]
    if config.has("eval_portion"):
        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * config.eval_portion)]

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)

    # setup model
    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim=speaker_embedding_dim)
    optimizer = RAdam(model.parameters(), lr=config.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
    criterion = SpeedySpeechLoss(config)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            optimizer.load_state_dict(checkpoint["optimizer"])
            if config.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint["model"])
        except:  # pylint: disable=bare-except
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group["initial_lr"] = config.lr
        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    if config.noam_schedule:
        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = config.keep_all_best
    keep_after = config.keep_after  # void if keep_all_best False

    # define dataloaders
    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

    global_step = args.restore_step
    for epoch in range(0, config.epochs):
        c_logger.print_epoch_start(epoch, config.epochs)
        train_avg_loss_dict, global_step = train(
            train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
        )
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict["avg_loss"]
        if config.run_eval:
            target_loss = eval_avg_loss_dict["avg_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            global_step,
            epoch,
            config.r,
            OUT_PATH,
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
        )


if __name__ == "__main__":
    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
@@ -1,749 +0,0 @@
#!/usr/bin/env python3
"""Trains Tacotron-based TTS models."""

import os
import sys
import time
import traceback
from random import randrange

import numpy as np
import torch
from torch.utils.data import DataLoader

from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import TacotronLoss
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import (
    NoamLR,
    adam_weight_decay,
    check_update,
    gradual_training_scheduler,
    set_weight_decay,
    setup_torch_training_env,
)

use_cuda, num_gpus = setup_torch_training_env(True, False)


def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
    if is_val and not config.run_eval:
        loader = None
    else:
        if dataset is None:
            dataset = MyDataset(
                r,
                config.text_cleaner,
                compute_linear_spec=config.model.lower() == "tacotron",
                meta_data=meta_data_eval if is_val else meta_data_train,
                ap=ap,
                tp=config.characters,
                add_blank=config["add_blank"],
                batch_group_size=0 if is_val else config.batch_group_size * config.batch_size,
                min_seq_len=config.min_seq_len,
                max_seq_len=config.max_seq_len,
                phoneme_cache_path=config.phoneme_cache_path,
                use_phonemes=config.use_phonemes,
                phoneme_language=config.phoneme_language,
                enable_eos_bos=config.enable_eos_bos_chars,
                verbose=verbose,
                speaker_mapping=(
                    speaker_mapping
                    if (config.use_speaker_embedding and config.use_external_speaker_embedding_file)
                    else None
                ),
            )

            if config.use_phonemes and config.compute_input_seq_cache:
                # precompute phonemes to have a better estimate of sequence lengths.
                dataset.compute_input_seq(config.num_loader_workers)
            dataset.sort_items()

        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=config.eval_batch_size if is_val else config.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=config.num_val_loader_workers if is_val else config.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    linear_input = data[3] if config.model.lower() in ["tacotron"] else None
    mel_input = data[4]
    mel_lengths = data[5]
    stop_targets = data[6]
    max_text_length = torch.max(text_lengths.float())
    max_spec_length = torch.max(mel_lengths.float())

    if config.use_speaker_embedding:
        if config.use_external_speaker_embedding_file:
            speaker_embeddings = data[8]
            speaker_ids = None
        else:
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
            speaker_embeddings = None
    else:
        speaker_embeddings = None
        speaker_ids = None

    # set stop targets view, we predict a single stop token per iteration.
    stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // config.r, -1)
    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input = text_input.cuda(non_blocking=True)
|
||||
text_lengths = text_lengths.cuda(non_blocking=True)
|
||||
mel_input = mel_input.cuda(non_blocking=True)
|
||||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
linear_input = linear_input.cuda(non_blocking=True) if config.model.lower() in ["tacotron"] else None
|
||||
stop_targets = stop_targets.cuda(non_blocking=True)
|
||||
if speaker_ids is not None:
|
||||
speaker_ids = speaker_ids.cuda(non_blocking=True)
|
||||
if speaker_embeddings is not None:
|
||||
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
|
||||
|
||||
return (
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
linear_input,
|
||||
stop_targets,
|
||||
speaker_ids,
|
||||
speaker_embeddings,
|
||||
max_text_length,
|
||||
max_spec_length,
|
||||
)
|
||||
|
||||
|
||||
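# Worked example of the stop-target reshaping above (illustrative numbers, not
# taken from the original script): with batch size 2, r = 5 and a padded
# spectrogram length of 10, `stop_targets` goes from shape (2, 10) to (2, 2, 5)
# via `.view()`, and the `(sum(2) > 0.0)` reduction yields one stop flag per
# decoder step, shape (2, 2): a step is flagged "stop" if any of its r frames
# is a stop frame.
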
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st):
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (config.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / config.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # setup lr
        if config.noam_schedule:
            scheduler.step()

        optimizer.zero_grad()
        if optimizer_st:
            optimizer_st.zero_grad()

        with torch.cuda.amp.autocast(enabled=config.mixed_precision):
            # forward pass model
            if config.bidirectional_decoder or config.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the alignment lengths wrt the reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
                ) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if config.mixed_precision:
            # model optimizer step in mixed precision mode
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
            scaler.step(optimizer)
            scaler.update()

            # stopnet optimizer step
            if config.separate_stopnet:
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler_st.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                scaler_st.step(optimizer_st)
                scaler_st.update()
            else:
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, config.grad_clip, ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if config.separate_stopnet:
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
            else:
                grad_norm_st = 0

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (
                reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if config.stopnet else loss_dict["stopnet_loss"]
            )

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % config.print_step == 0:
            log_dict = {
                "max_spec_length": [max_spec_length, 1],  # value, precision
                "max_text_length": [max_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % config.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % config.save_step == 0:
                if config.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict() if config.mixed_precision else None,
                    )

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = (
                    linear_input[0].data.cpu().numpy()
                    if config.model in ["Tacotron", "TacotronGST"]
                    else mel_input[0].data.cpu().numpy()
                )
                align_img = alignments[0].data.cpu().numpy()

                figures = {
                    "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
                    "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                    "alignment": plot_alignment(align_img, output_fig=False),
                }

                if config.bidirectional_decoder or config.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(), output_fig=False
                    )

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                if config.model in ["Tacotron", "TacotronGST"]:
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, config.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if config.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step

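# Note on the mixed-precision branch in `train()`: `scaler.unscale_(optimizer)`
# is called before `check_update()` so that gradient clipping and the reported
# `grad_norm` operate on true, unscaled gradients; `scaler.step()` then skips
# the optimizer update automatically whenever the scaled gradients overflowed.
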
@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            (
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                linear_input,
                stop_targets,
                speaker_ids,
                speaker_embeddings,
                _,
                _,
            ) = format_data(data)
            assert mel_input.shape[1] % model.decoder.r == 0

            # forward pass model
            if config.bidirectional_decoder or config.double_decoder_consistency:
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the alignment lengths wrt the reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (
                    mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
                ) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
                loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
                if config.stopnet:
                    loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

            if config.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            idx = np.random.randint(mel_input.shape[0])
            const_spec = postnet_output[idx].data.cpu().numpy()
            gt_spec = (
                linear_input[idx].data.cpu().numpy()
                if config.model in ["Tacotron", "TacotronGST"]
                else mel_input[idx].data.cpu().numpy()
            )
            align_img = alignments[idx].data.cpu().numpy()

            eval_figures = {
                "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False),
            }

            # Sample audio
            if config.model.lower() in ["tacotron"]:
                eval_audio = ap.inv_spectrogram(const_spec.T)
            else:
                eval_audio = ap.inv_melspectrogram(const_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, config.audio["sample_rate"])

            # Plot Validation Stats
            if config.bidirectional_decoder or config.double_decoder_consistency:
                align_b_img = alignments_backward[idx].data.cpu().numpy()
                eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False)
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch > config.test_delay_epochs:
        if config.test_sentences_file:
            with open(config.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]
        else:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963.",
            ]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if config.use_speaker_embedding else None
        speaker_embedding = (
            speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"]
            if config.use_external_speaker_embedding_file and config.use_speaker_embedding
            else None
        )
        style_wav = config.gst_style_input
        if style_wav is None and config.gst is not None:
            # initialize GST with a zero dict.
            style_wav = {}
            print("WARNING: No GST style wav was provided, so a zero style input is used instead!")
            for i in range(config.gst["gst_num_style_tokens"]):
                style_wav[str(i)] = 0
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, decoder_output, postnet_output, stop_tokens, _ = synthesis(
                    model,
                    test_sentence,
                    config,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=config.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios, config.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values

def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, speaker_mapping, symbols, phonemes, model_characters
    # Audio processor
    ap = AudioProcessor(**config.audio.to_dict())

    # setup custom characters if set in config file.
    if config.characters is not None:
        symbols, phonemes = make_symbols(**config.characters.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, config.distributed["backend"], config.distributed["url"])
    num_chars = len(phonemes) if config.use_phonemes else len(symbols)
    model_characters = phonemes if config.use_phonemes else symbols

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(config.datasets)

    # set the portion of the data used for training
    if config.has("train_portion"):
        meta_data_train = meta_data_train[: int(len(meta_data_train) * config.train_portion)]
    if config.has("eval_portion"):
        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * config.eval_portion)]

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(config, args, meta_data_train, OUT_PATH)

    model = setup_model(num_chars, num_speakers, config, speaker_embedding_dim)

    # scalers for mixed precision training
    scaler = torch.cuda.amp.GradScaler() if config.mixed_precision else None
    scaler_st = torch.cuda.amp.GradScaler() if config.mixed_precision and config.separate_stopnet else None

    params = set_weight_decay(model, config.wd)
    optimizer = RAdam(params, lr=config.lr, weight_decay=0)
    if config.stopnet and config.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=config.lr, weight_decay=0)
    else:
        optimizer_st = None

    # setup criterion
    criterion = TacotronLoss(config, stopnet_pos_weight=config.stopnet_pos_weight, ga_sigma=0.4)
    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            # optimizer restore
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scaler" in checkpoint and config.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], config)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group["lr"] = config.lr
        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)

    if config.noam_schedule:
        scheduler = NoamLR(optimizer, warmup_steps=config.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = config.keep_all_best
    keep_after = config.keep_after  # void if keep_all_best is False

    # define data loaders
    train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, model.decoder.r, is_val=True)

    global_step = args.restore_step
    for epoch in range(0, config.epochs):
        c_logger.print_epoch_start(epoch, config.epochs)
        # set gradual training
        if config.gradual_training is not None:
            r, config.batch_size = gradual_training_scheduler(global_step, config)
            config.r = r
            model.decoder.set_r(r)
            if config.bidirectional_decoder:
                model.decoder_backward.set_r(r)
            train_loader.dataset.outputs_per_step = r
            eval_loader.dataset.outputs_per_step = r
            train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset)
            eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset)
            print("\n > Number of output frames:", model.decoder.r)
        # train one epoch
        train_avg_loss_dict, global_step = train(
            train_loader,
            model,
            criterion,
            optimizer,
            optimizer_st,
            scheduler,
            ap,
            global_step,
            epoch,
            scaler,
            scaler_st,
        )
        # eval one epoch
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict["avg_postnet_loss"]
        if config.run_eval:
            target_loss = eval_avg_loss_dict["avg_postnet_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            global_step,
            epoch,
            config.r,
            OUT_PATH,
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            scaler=scaler.state_dict() if config.mixed_precision else None,
        )


if __name__ == "__main__":
    args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
@ -0,0 +1,14 @@
import sys

from TTS.trainer import Trainer, init_training


def main():
    """Run the 🐸TTS trainer from the terminal. This is also required to run DDP training with ``distribute.py``."""
    args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
    trainer = Trainer(args, config, output_path, c_logger, tb_logger, cudnn_benchmark=False)
    trainer.fit()


if __name__ == "__main__":
    main()
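# Example invocation (illustrative only; the script name and config path below
# are placeholders, not taken from this repository):
#
#   CUDA_VISIBLE_DEVICES=0 python train_script.py --config_path /path/to/config.json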
@ -0,0 +1,27 @@
import os
import sys
import traceback

from TTS.trainer import Trainer, init_training
from TTS.utils.generic_utils import remove_experiment_folder


def main():
    try:
        args, config, output_path, _, c_logger, tb_logger = init_training(sys.argv)
        trainer = Trainer(args, config, output_path, c_logger, tb_logger)
        trainer.fit()
    except KeyboardInterrupt:
        remove_experiment_folder(output_path)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(output_path)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
@ -1,638 +0,0 @@
#!/usr/bin/env python3
# TODO: mixed precision training
"""Trains GAN-based vocoder models."""

import itertools
import os
import sys
import time
import traceback
from inspect import signature

import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

use_cuda, num_gpus = setup_torch_training_env(True, True)

def setup_loader(ap, is_val=False, verbose=False):
    loader = None
    if not is_val or c.run_eval:
        dataset = GANDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=c.seq_len,
            hop_len=ap.hop_length,
            pad_short=c.pad_short,
            conv_pad=c.conv_pad,
            return_pairs=c.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in c else False,
            is_training=not is_val,
            return_segments=not is_val,
            use_noise_augment=c.use_noise_augment,
            use_cache=c.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=1 if is_val else c.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    if isinstance(data[0], list):
        x_G, y_G = data[0]
        x_D, y_D = data[1]
        if use_cuda:
            x_G = x_G.cuda(non_blocking=True)
            y_G = y_G.cuda(non_blocking=True)
            x_D = x_D.cuda(non_blocking=True)
            y_D = y_D.cuda(non_blocking=True)
        return x_G, y_G, x_D, y_D
    x, y = data
    if use_cuda:
        x = x.cuda(non_blocking=True)
        y = y.cuda(non_blocking=True)
    return x, y, None, None

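# Note: `format_data()` has two layouts. When `diff_samples_for_G_and_D` is
# enabled, the dataset yields two (features, audio) pairs per batch and the
# function returns (x_G, y_G, x_D, y_D); otherwise it returns (x, y, None, None)
# and `train()` below reuses clones of the generator batch for the discriminator.
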
def train(
    model_G,
    criterion_G,
    optimizer_G,
    model_D,
    criterion_D,
    optimizer_D,
    scheduler_G,
    scheduler_D,
    ap,
    global_step,
    epoch,
):
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, c_D, y_D = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)
        y_hat_sub = None
        y_G_sub = None
        y_hat_vis = y_hat  # for visualization

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_hat_vis = y_hat
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    # we don't need scores for real samples for training G since they are always 1
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(
            y_hat=y_hat,
            y=y_G,
            scores_fake=scores_fake,
            feats_fake=feats_fake,
            feats_real=feats_real,
            y_hat_sub=y_hat_sub,
            y_sub=y_G_sub,
        )
        loss_G = loss_G_dict["G_loss"]

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
        optimizer_G.step()

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, int):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################
        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            if c.diff_samples_for_G_and_D:
                # use a different sample than generator
                with torch.no_grad():
                    y_hat = model_G(c_D)

                # PQMF formatting
                if y_hat.shape[1] > 1:
                    y_hat = model_G.pqmf_synthesis(y_hat)
            else:
                # use the same samples as generator
                c_D = c_G.clone()
                y_D = y_G.clone()

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach().clone(), c_D)
                D_out_real = model_D(y_D, c_D)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_D)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                # model_D returns scores and features
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                # model_D returns only scores
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)
            loss_D = loss_D_dict["D_loss"]

            # optimizer discriminator
            optimizer_D.zero_grad()
            loss_D.backward()
            if c.disc_clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
            optimizer_D.step()

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr_G = list(optimizer_G.param_groups)[0]["lr"]
        current_lr_D = list(optimizer_D.param_groups)[0]["lr"]

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr_G": current_lr_G,
                "current_lr_D": current_lr_D,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model_G,
                        optimizer_G,
                        scheduler_G,
                        model_D,
                        optimizer_D,
                        scheduler_D,
                        global_step,
                        epoch,
                        OUT_PATH,
                        model_losses=loss_dict,
                    )

                # compute spectrograms
                figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
                tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"])
        end_time = time.time()

    if scheduler_G is not None:
        scheduler_G.step()

    if scheduler_D is not None:
        scheduler_D.step()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    torch.cuda.empty_cache()
    return keep_avg.avg_values, global_step

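# Note: until `global_step` passes `c.steps_to_start_discriminator`, the
# generator trains alone; `scores_fake`/`feats_*` stay None, so the generator
# criterion is expected to fall back to its non-adversarial terms, and the whole
# discriminator branch is skipped.
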
@torch.no_grad()
def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model_G.eval()
    model_D.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        c_G, y_G, _, _ = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        ##############################
        # GENERATOR
        ##############################

        # generator pass
        y_hat = model_G(c_G)[:, :, : y_G.size(2)]
        y_hat_sub = None
        y_G_sub = None

        # PQMF formatting
        if y_hat.shape[1] > 1:
            y_hat_sub = y_hat
            y_hat = model_G.pqmf_synthesis(y_hat)
            y_G_sub = model_G.pqmf_analysis(y_G)

        scores_fake, feats_fake, feats_real = None, None, None
        if global_step > c.steps_to_start_discriminator:

            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat, c_G)
            else:
                D_out_fake = model_D(y_hat)
            D_out_real = None

            if c.use_feat_match_loss:
                with torch.no_grad():
                    D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    feats_real = None
                else:
                    _, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                feats_fake, feats_real = None, None

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)

        loss_dict = dict()
        for key, value in loss_G_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        ##############################
        # DISCRIMINATOR
        ##############################

        if global_step >= c.steps_to_start_discriminator:
            # discriminator pass
            with torch.no_grad():
                y_hat = model_G(c_G)[:, :, : y_G.size(2)]

            # PQMF formatting
            if y_hat.shape[1] > 1:
                y_hat = model_G.pqmf_synthesis(y_hat)

            # run D with or without cond. features
            if len(signature(model_D.forward).parameters) == 2:
                D_out_fake = model_D(y_hat.detach(), c_G)
                D_out_real = model_D(y_G, c_G)
            else:
                D_out_fake = model_D(y_hat.detach())
                D_out_real = model_D(y_G)

            # format D outputs
            if isinstance(D_out_fake, tuple):
                scores_fake, feats_fake = D_out_fake
                if D_out_real is None:
                    scores_real, feats_real = None, None
                else:
                    scores_real, feats_real = D_out_real
            else:
                scores_fake = D_out_fake
                scores_real = D_out_real

            # compute losses
            loss_D_dict = criterion_D(scores_fake, scores_real)

            for key, value in loss_D_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict[key] = value
                else:
                    loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if args.rank == 0:
        # compute spectrograms
        figures = plot_results(y_hat, y_G, ap, global_step, "eval")
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        predict_waveform = y_hat[0].squeeze(0).detach().cpu().numpy()
        real_waveform = y_G[0].squeeze(0).cpu().numpy()
        tb_logger.tb_eval_audios(
            global_step, {"eval/audio": predict_waveform, "eval/real_waveform": real_waveform}, c.audio["sample_rate"]
        )

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

    # synthesize a full voice
    data_loader.return_segments = False
    torch.cuda.empty_cache()
    return keep_avg.avg_values

def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

    # setup models
    model_gen = setup_generator(c)
    model_disc = setup_discriminator(c)

    # setup criterion
    criterion_gen = GeneratorLoss(c)
    criterion_disc = DiscriminatorLoss(c)

    if use_cuda:
        model_gen.cuda()
        criterion_gen.cuda()
        model_disc.cuda()
        criterion_disc.cuda()

    # setup optimizers
    # TODO: allow loading custom optimizers
    optimizer_gen = None
    optimizer_disc = None
    optimizer_gen = getattr(torch.optim, c.optimizer)
    optimizer_gen = optimizer_gen(model_gen.parameters(), lr=c.lr_gen, **c.optimizer_params)
    optimizer_disc = getattr(torch.optim, c.optimizer)

    if c.discriminator_model == "hifigan_discriminator":
        optimizer_disc = optimizer_disc(
            itertools.chain(model_disc.msd.parameters(), model_disc.mpd.parameters()),
            lr=c.lr_disc,
            **c.optimizer_params,
        )
    else:
        optimizer_disc = optimizer_disc(model_disc.parameters(), lr=c.lr_disc, **c.optimizer_params)

    # schedulers
    scheduler_gen = None
    scheduler_disc = None
    if "lr_scheduler_gen" in c:
        scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
        scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
    if "lr_scheduler_disc" in c:
        scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
        scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Generator Model...")
            model_gen.load_state_dict(checkpoint["model"])
            print(" > Restoring Generator Optimizer...")
            optimizer_gen.load_state_dict(checkpoint["optimizer"])
            print(" > Restoring Discriminator Model...")
            model_disc.load_state_dict(checkpoint["model_disc"])
            print(" > Restoring Discriminator Optimizer...")
            optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
            # restore schedulers if it is a continuing training.
            if args.continue_path != "":
                if "scheduler" in checkpoint and scheduler_gen is not None:
                    print(" > Restoring Generator LR Scheduler...")
                    scheduler_gen.load_state_dict(checkpoint["scheduler"])
                    # NOTE: Not sure if necessary
                    scheduler_gen.optimizer = optimizer_gen
                if "scheduler_disc" in checkpoint and scheduler_disc is not None:
                    print(" > Restoring Discriminator LR Scheduler...")
                    scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
                    scheduler_disc.optimizer = optimizer_disc
                    if c.lr_scheduler_disc == "ExponentialLR":
                        scheduler_disc.last_epoch = checkpoint["epoch"]
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_gen.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_gen.load_state_dict(model_dict)

            model_dict = model_disc.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
            model_disc.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        if args.continue_path == "":
            for group in optimizer_gen.param_groups:
                group["lr"] = c.lr_gen

            for group in optimizer_disc.param_groups:
                group["lr"] = c.lr_disc

        print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    if num_gpus > 1:
        model_gen = DDP_th(model_gen, device_ids=[args.rank])
        model_disc = DDP_th(model_disc, device_ids=[args.rank])

    num_params = count_parameters(model_gen)
    print(" > Generator has {} parameters".format(num_params), flush=True)
    num_params = count_parameters(model_disc)
    print(" > Discriminator has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with best loss of {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best is False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(
            model_gen,
            criterion_gen,
            optimizer_gen,
            model_disc,
            criterion_disc,
            optimizer_disc,
            scheduler_gen,
            scheduler_disc,
            ap,
            global_step,
            epoch,
        )
        eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_gen,
            optimizer_gen,
            scheduler_gen,
            model_disc,
            optimizer_disc,
            scheduler_disc,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
        )


if __name__ == "__main__":
    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
@ -1,431 +0,0 @@
#!/usr/bin/env python3
"""Trains WaveGrad vocoder models."""

import os
import sys
import time
import traceback

import numpy as np
import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import plot_results, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

use_cuda, num_gpus = setup_torch_training_env(True, True)

def setup_loader(ap, is_val=False, verbose=False):
    if is_val and not c.run_eval:
        loader = None
    else:
        dataset = WaveGradDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=c.seq_len,
            hop_len=ap.hop_length,
            pad_short=c.pad_short,
            conv_pad=c.conv_pad,
            is_training=not is_val,
            return_segments=True,
            use_noise_augment=False,
            use_cache=c.use_cache,
            verbose=verbose,
        )
        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=c.batch_size,
            shuffle=num_gpus <= 1,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=False,
        )

    return loader


def format_data(data):
    # return a whole audio segment
    m, x = data
    x = x.unsqueeze(1)
    if use_cuda:
        m = m.cuda(non_blocking=True)
        x = x.cuda(non_blocking=True)
    return m, x


def format_test_data(data):
    # return a whole audio segment
    m, x = data
    m = m[None, ...]
    x = x[None, None, ...]
    if use_cuda:
        m = m.cuda(non_blocking=True)
        x = x.cuda(non_blocking=True)
    return m, x

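# Sketch of the WaveGrad noise schedule consumed by `train()` below. The numbers
# are illustrative defaults, not values from this config:
#
#   "train_noise_schedule": {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}
#
# `np.linspace(min_val, max_val, num_steps)` produces the per-step beta values,
# and `compute_noise_level(betas)` is assumed to precompute the cumulative noise
# levels that `compute_y_n()` samples from when it builds the noisy input
# `x_noisy`.
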
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # setup noise schedule
    noise_schedule = c["train_noise_schedule"]
    betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
    if hasattr(model, "module"):
        model.module.compute_noise_level(betas)
    else:
        model.compute_noise_level(betas)
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # compute noisy input
            if hasattr(model, "module"):
                noise, x_noisy, noise_scale = model.module.compute_y_n(x)
            else:
                noise, x_noisy, noise_scale = model.compute_y_n(x)

            # forward pass
            noise_hat = model(x_noisy, m, noise_scale)

            # compute losses
            loss = criterion(noise, noise_hat)
            loss_wavegrad_dict = {"wavegrad_loss": loss}

        # check nan loss
        if torch.isnan(loss).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        optimizer.zero_grad()

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            optimizer.step()

        # schedule update
        if scheduler is not None:
            scheduler.step()

        # disconnect loss values
        loss_dict = dict()
        for key, value in loss_wavegrad_dict.items():
            if isinstance(value, int):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        # epoch/step timing
        step_time = time.time() - start_time
        epoch_time += step_time

        # get current learning rates
        current_lr = list(optimizer.param_groups)[0]["lr"]

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": current_lr,
                "grad_norm": grad_norm.item(),
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # plot step stats
            if global_step % 10 == 0:
                iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            # save checkpoint
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(
                        model,
                        optimizer,
                        scheduler,
                        None,
                        None,
                        None,
                        global_step,
                        epoch,
                        OUT_PATH,
                        model_losses=loss_dict,
                        scaler=scaler.state_dict() if c.mixed_precision else None,
                    )

        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    if args.rank == 0:
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    if c.tb_model_param_stats and args.rank == 0:
        tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step

@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        m, x = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1

        # compute noisy input
        if hasattr(model, "module"):
            noise, x_noisy, noise_scale = model.module.compute_y_n(x)
        else:
            noise, x_noisy, noise_scale = model.compute_y_n(x)

        # forward pass
        noise_hat = model(x_noisy, m, noise_scale)

        # compute losses
        loss = criterion(noise, noise_hat)
        loss_wavegrad_dict = {"wavegrad_loss": loss}

        loss_dict = dict()
        for key, value in loss_wavegrad_dict.items():
            if isinstance(value, (int, float)):
                loss_dict[key] = value
            else:
                loss_dict[key] = value.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if args.rank == 0:
        data_loader.dataset.return_segments = False
        samples = data_loader.dataset.load_test_samples(1)
        m, x = format_test_data(samples[0])

        # setup noise schedule and inference
        noise_schedule = c["test_noise_schedule"]
        betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
        if hasattr(model, "module"):
            model.module.compute_noise_level(betas)
            # compute voice
            x_pred = model.module.inference(m)
        else:
            model.compute_noise_level(betas)
            # compute voice
            x_pred = model.inference(m)

        # compute spectrograms
        figures = plot_results(x_pred, x, ap, global_step, "eval")
        tb_logger.tb_eval_figures(global_step, figures)

        # Sample audio
        sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
        tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])

        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        data_loader.dataset.return_segments = True

    return keep_avg.avg_values

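# Note: in `evaluate()` above, `return_segments` is toggled off so the dataset
# yields a full-length utterance for the test sample; `model.inference(m)` is
# then assumed to run the reverse diffusion over the `test_noise_schedule` to
# synthesize the complete waveform before segment mode is restored.
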
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data
    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

    # setup models
    model = setup_generator(c)

    # scaler for mixed_precision
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # setup optimizers
    optimizer = Adam(model.parameters(), lr=c.lr, weight_decay=0)

    # schedulers
    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)

    # setup criterion
    criterion = torch.nn.L1Loss().cuda()

    if use_cuda:
        model.cuda()
        criterion.cuda()

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                # NOTE: Not sure if necessary
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model.load_state_dict(model_dict)
            del model_dict

        # reset lr if not continuing training.
        for group in optimizer.param_groups:
            group["lr"] = c.lr

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

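The `except RuntimeError` branch delegates to `set_init_dict()` to keep only the checkpoint tensors that still fit the current model. A minimal sketch of that matching-layers idea (`partial_restore` is a hypothetical helper; the real `set_init_dict` also applies config-dependent rules):

def partial_restore(model, checkpoint_state):
    """Copy only checkpoint tensors whose name and shape match the current model."""
    model_state = model.state_dict()
    matched = {
        k: v for k, v in checkpoint_state.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    model_state.update(matched)
    model.load_state_dict(model_state)
    print(f" | > {len(matched)} / {len(model_state)} layers restored.")
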
    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    num_params = count_parameters(model)
    print(" > WaveGrad has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict[c.target_loss]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None,
        )


if __name__ == "__main__":
    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
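`save_best_model()` is the compare-and-save step of the loop above. Reduced to its core it behaves roughly like this (a sketch; the real helper also rotates old best files and honors `keep_all_best` / `keep_after`):

import os
import torch

def save_if_best(target_loss, best_loss, model, step, out_path):
    """Persist a checkpoint only when the eval loss improves."""
    if target_loss < best_loss:
        state = {"model": model.state_dict(), "step": step, "model_loss": target_loss}
        torch.save(state, os.path.join(out_path, f"best_model_{step}.pth.tar"))
        best_loss = target_loss
    return best_loss
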
@@ -1,431 +0,0 @@
#!/usr/bin/env python3
"""Train WaveRNN vocoder model."""

import os
import random
import sys
import time
import traceback

import torch
from torch.utils.data import DataLoader

from TTS.tts.utils.visual import plot_spectrogram
from TTS.utils.arguments import init_training
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

# from torch.utils.data.distributed import DistributedSampler


use_cuda, num_gpus = setup_torch_training_env(True, True)

def setup_loader(ap, is_val=False, verbose=False):
    if is_val and not c.run_eval:
        loader = None
    else:
        dataset = WaveRNNDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=c.seq_len,
            hop_len=ap.hop_length,
            pad=c.padding,
            mode=c.mode,
            mulaw=c.mulaw,
            is_training=not is_val,
            verbose=verbose,
        )
        # sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            shuffle=True,
            collate_fn=dataset.collate,
            batch_size=c.batch_size,
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=True,
        )
    return loader

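`setup_loader` hands batching over to the dataset's own `collate` method. The same `collate_fn` mechanism in plain PyTorch, with toy data standing in for the WaveRNN tensors:

import torch
from torch.utils.data import DataLoader

def pad_collate(batch):
    """Pad a list of 1D tensors of varying length into a single batch tensor."""
    max_len = max(item.shape[0] for item in batch)
    padded = torch.zeros(len(batch), max_len)
    for i, item in enumerate(batch):
        padded[i, : item.shape[0]] = item
    return padded

toy_items = [torch.randn(n) for n in (100, 80, 120)]
loader = DataLoader(toy_items, batch_size=3, collate_fn=pad_collate)
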
def format_data(data):
    # setup input data
    x_input = data[0]
    mels = data[1]
    y_coarse = data[2]

    # dispatch data to GPU
    if use_cuda:
        x_input = x_input.cuda(non_blocking=True)
        mels = mels.cuda(non_blocking=True)
        y_coarse = y_coarse.cuda(non_blocking=True)

    return x_input, mels, y_coarse

def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch):
    # create train loader
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    # train loop
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        optimizer.zero_grad()

        if c.mixed_precision:
            # mixed precision training
            with torch.cuda.amp.autocast():
                y_hat = model(x_input, mels)
                if isinstance(model.mode, int):
                    y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
                else:
                    y_coarse = y_coarse.float()
                y_coarse = y_coarse.unsqueeze(-1)
                # compute losses
                loss = criterion(y_hat, y_coarse)
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            # full precision training
            y_hat = model(x_input, mels)
            if isinstance(model.mode, int):
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            else:
                y_coarse = y_coarse.float()
            y_coarse = y_coarse.unsqueeze(-1)
            # compute losses
            loss = criterion(y_hat, y_coarse)
            if torch.isnan(loss):
                raise RuntimeError(" [!] NaN loss. Exiting ...")
            loss.backward()
            if c.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        # get the current learning rate
        cur_lr = list(optimizer.param_groups)[0]["lr"]

        step_time = time.time() - start_time
        epoch_time += step_time

        update_train_values = dict()
        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()
        for key, value in loss_dict.items():
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training stats
        if global_step % c.print_step == 0:
            log_dict = {
                "step_time": [step_time, 2],
                "loader_time": [loader_time, 4],
                "current_lr": cur_lr,
            }
            c_logger.print_train_step(
                batch_n_iter,
                num_iter,
                global_step,
                log_dict,
                loss_dict,
                keep_avg.avg_values,
            )

        # plot step stats
        if global_step % 10 == 0:
            iter_stats = {"lr": cur_lr, "step_time": step_time}
            iter_stats.update(loss_dict)
            tb_logger.tb_train_iter_stats(global_step, iter_stats)

        # save checkpoint
        if global_step % c.save_step == 0:
            if c.checkpoint:
                # save model
                save_checkpoint(
                    model,
                    optimizer,
                    scheduler,
                    None,
                    None,
                    None,
                    global_step,
                    epoch,
                    OUT_PATH,
                    model_losses=loss_dict,
                    scaler=scaler.state_dict() if c.mixed_precision else None,
                )

            # synthesize a full voice
            rand_idx = random.randrange(0, len(train_data))
            wav_path = (
                train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
            )
            wav = ap.load_wav(wav_path)
            ground_mel = ap.melspectrogram(wav)
            ground_mel = torch.FloatTensor(ground_mel)
            if use_cuda:
                ground_mel = ground_mel.cuda(non_blocking=True)
            sample_wav = model.inference(
                ground_mel,
                c.batched,
                c.target_samples,
                c.overlap_samples,
            )
            predict_mel = ap.melspectrogram(sample_wav)

            # compute spectrograms
            figures = {
                "train/ground_truth": plot_spectrogram(ground_mel.T),
                "train/prediction": plot_spectrogram(predict_mel.T),
            }
            tb_logger.tb_train_figures(global_step, figures)

            # Sample audio
            tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Training Epoch Stats
    epoch_stats = {"epoch_time": epoch_time}
    epoch_stats.update(keep_avg.avg_values)
    tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
    # TODO: plot model stats
    # if c.tb_model_param_stats:
    #     tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step

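The mixed-precision branch above follows the canonical `torch.cuda.amp` recipe: run the forward pass under `autocast`, backpropagate the scaled loss, unscale before clipping, then step through the scaler. The skeleton in isolation (model, data, and hyperparameters are placeholders):

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x, y = torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.l1_loss(model(x), y)
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.unscale_(optimizer)      # unscale so clipping sees true gradient magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
scaler.step(optimizer)          # skips the update if gradients overflowed
scaler.update()
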
@torch.no_grad()
def evaluate(model, criterion, ap, global_step, epoch):
    # create eval loader
    data_loader = setup_loader(ap, is_val=True, verbose=(epoch == 0))
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    end_time = time.time()
    c_logger.print_eval_start()
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()
        # format data
        x_input, mels, y_coarse = format_data(data)
        loader_time = time.time() - end_time
        global_step += 1

        y_hat = model(x_input, mels)
        if isinstance(model.mode, int):
            y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
        else:
            y_coarse = y_coarse.float()
        y_coarse = y_coarse.unsqueeze(-1)
        loss = criterion(y_hat, y_coarse)
        # Compute avg loss
        # if num_gpus > 1:
        #     loss = reduce_tensor(loss.data, num_gpus)
        loss_dict = dict()
        loss_dict["model_loss"] = loss.item()

        step_time = time.time() - start_time
        epoch_time += step_time

        # update avg stats
        update_eval_values = dict()
        for key, value in loss_dict.items():
            update_eval_values["avg_" + key] = value
        update_eval_values["avg_loader_time"] = loader_time
        update_eval_values["avg_step_time"] = step_time
        keep_avg.update_values(update_eval_values)

        # print eval stats
        if c.print_eval:
            c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

    if epoch % c.test_every_epochs == 0 and epoch != 0:
        # synthesize a full voice
        rand_idx = random.randrange(0, len(eval_data))
        wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
        wav = ap.load_wav(wav_path)
        ground_mel = ap.melspectrogram(wav)
        ground_mel = torch.FloatTensor(ground_mel)
        if use_cuda:
            ground_mel = ground_mel.cuda(non_blocking=True)
        sample_wav = model.inference(
            ground_mel,
            c.batched,
            c.target_samples,
            c.overlap_samples,
        )
        predict_mel = ap.melspectrogram(sample_wav)

        # Sample audio
        tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"])

        # compute spectrograms
        figures = {
            "eval/ground_truth": plot_spectrogram(ground_mel.T),
            "eval/prediction": plot_spectrogram(predict_mel.T),
        }
        tb_logger.tb_eval_figures(global_step, figures)

    tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
    return keep_avg.avg_values

# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global train_data, eval_data

    # setup audio processor
    ap = AudioProcessor(**c.audio.to_dict())

    print(f" > Loading wavs from: {c.data_path}")
    if c.feature_path is not None:
        print(f" > Loading features from: {c.feature_path}")
        eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
    else:
        eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
    # setup model
    model_wavernn = setup_generator(c)

    # setup amp scaler
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None

    # define train functions
    if c.mode == "mold":
        criterion = discretized_mix_logistic_loss
    elif c.mode == "gauss":
        criterion = gaussian_loss
    elif isinstance(c.mode, int):
        criterion = torch.nn.CrossEntropyLoss()

    if use_cuda:
        model_wavernn.cuda()
        if isinstance(c.mode, int):
            criterion.cuda()

    optimizer = RAdam(model_wavernn.parameters(), lr=c.lr, weight_decay=0)

    scheduler = None
    if "lr_scheduler" in c:
        scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
        scheduler = scheduler(optimizer, **c.lr_scheduler_params)
    # slow start for the first 5 epochs
    # lr_lambda = lambda epoch: min(epoch / c.warmup_steps, 1)
    # scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # restore any checkpoint
    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model_wavernn.load_state_dict(checkpoint["model"])
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scheduler" in checkpoint:
                print(" > Restoring Generator LR Scheduler...")
                scheduler.load_state_dict(checkpoint["scheduler"])
                scheduler.optimizer = optimizer
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
        except RuntimeError:
            # restore only matching layers.
            print(" > Partial model initialization...")
            model_dict = model_wavernn.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model_wavernn.load_state_dict(model_dict)

        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0

    # DISTRIBUTED
    # if num_gpus > 1:
    #     model = apply_gradient_allreduce(model)

    num_parameters = count_parameters(model_wavernn)
    print(" > Model has {} parameters".format(num_parameters), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        _, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch)
        eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = eval_avg_loss_dict["avg_model_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model_wavernn,
            optimizer,
            scheduler,
            None,
            None,
            None,
            global_step,
            epoch,
            OUT_PATH,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            model_losses=eval_avg_loss_dict,
            scaler=scaler.state_dict() if c.mixed_precision else None,
        )


if __name__ == "__main__":
    args, c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
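Both trainers resolve the LR scheduler by name via `getattr` on `torch.optim.lr_scheduler`, so the schedule is fully config-driven. A self-contained illustration (the config dict is invented for the demo):

import torch

cfg = {"lr_scheduler": "StepLR", "lr_scheduler_params": {"step_size": 10, "gamma": 0.5}}

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler_class = getattr(torch.optim.lr_scheduler, cfg["lr_scheduler"])
scheduler = scheduler_class(optimizer, **cfg["lr_scheduler_params"])
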
@@ -1,8 +1,10 @@
import json
import os
import re
from typing import Dict

import yaml
from coqpit import Coqpit

from TTS.config.shared_configs import *
from TTS.utils.generic_utils import find_module

@@ -20,7 +22,18 @@ def read_json_with_comments(json_path):
    return data


def _search_configs(model_name):
def register_config(model_name: str) -> Coqpit:
    """Find the right config for the given model name.

    Args:
        model_name (str): Model name.

    Raises:
        ModuleNotFoundError: No matching config for the model name.

    Returns:
        Coqpit: config class.
    """
    config_class = None
    paths = ["TTS.tts.configs", "TTS.vocoder.configs", "TTS.speaker_encoder"]
    for path in paths:
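`register_config` walks a list of module paths and asks `find_module` for the class matching the model name. A rough standard-library equivalent of that lookup (the function and its matching rule are illustrative, not the project's actual helper):

import importlib

def find_config_class(model_name, search_paths):
    """Return the first class whose name matches `<modelname>config`, case-insensitively."""
    target = model_name.replace("_", "") + "config"
    for path in search_paths:
        try:
            module = importlib.import_module(path)
        except ModuleNotFoundError:
            continue
        for attr in dir(module):
            if attr.lower() == target:
                return getattr(module, attr)
    raise ModuleNotFoundError(f" [!] No config found for {model_name}")
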
@@ -33,7 +46,15 @@ def _search_configs(model_name):
    return config_class


def _process_model_name(config_dict):
def _process_model_name(config_dict: Dict) -> str:
    """Format the model name as expected. It is a band-aid for the old `vocoder` model names.

    Args:
        config_dict (Dict): A dictionary including the config fields.

    Returns:
        str: Formatted model name.
    """
    model_name = config_dict["model"] if "model" in config_dict else config_dict["generator_model"]
    model_name = model_name.replace("_generator", "").replace("_discriminator", "")
    return model_name
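The band-aid maps legacy vocoder names onto their base model name, e.g. (illustrative input):

    >>> _process_model_name({"generator_model": "multiband_melgan_generator"})
    'multiband_melgan'
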
@@ -69,7 +90,7 @@ def load_config(config_path: str) -> None:
        raise TypeError(f" [!] Unknown config file type {ext}")
    config_dict.update(data)
    model_name = _process_model_name(config_dict)
    config_class = _search_configs(model_name.lower())
    config_class = register_config(model_name.lower())
    config = config_class()
    config.from_dict(config_dict)
    return config

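End to end, `load_config` reads a JSON or YAML file into a dict, derives the model name, and hydrates the matching Coqpit dataclass. A usage sketch (the path is a placeholder):

from TTS.config import load_config

config = load_config("config.json")  # e.g. returns a GlowTTSConfig instance
print(config.model)
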
@@ -1,7 +1,7 @@
from dataclasses import asdict, dataclass
from typing import List

from coqpit import MISSING, Coqpit, check_argument
from coqpit import Coqpit, check_argument


@dataclass

@@ -180,6 +180,14 @@ class BaseTrainingConfig(Coqpit):
    among all the models.

    Args:
        model (str):
            Name of the model that is used in the training.
        run_name (str):
            Name of the experiment. This prefixes the output folder name.
        run_description (str):
            Short description of the experiment.
        epochs (int):
            Number of training epochs. Defaults to 10000.
        batch_size (int):
            Training batch size.
        eval_batch_size (int):

@@ -214,7 +222,7 @@ class BaseTrainingConfig(Coqpit):
            to 10000.
        num_loader_workers (int):
            Number of workers for training time dataloader.
        num_val_loader_workers (int):
        num_eval_loader_workers (int):
            Number of workers for evaluation time dataloader.
        output_path (str):
            Path for training output folder. The nonexistent part of the given path is created automatically.

@@ -243,8 +251,8 @@ class BaseTrainingConfig(Coqpit):
    keep_all_best: bool = False
    keep_after: int = 10000
    # dataloading
    num_loader_workers: int = MISSING
    num_val_loader_workers: int = 0
    num_loader_workers: int = None
    num_eval_loader_workers: int = 0
    use_noise_augment: bool = False
    # paths
    output_path: str = None

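Replacing coqpit's `MISSING` sentinel with `None` makes the field optional instead of mandatory. The pattern in miniature (a hypothetical config, assuming coqpit's documented dataclass API):

from dataclasses import dataclass
from coqpit import Coqpit

@dataclass
class MyTrainingConfig(Coqpit):
    epochs: int = 10000
    batch_size: int = 32
    num_loader_workers: int = None  # optional now; MISSING would have required a value

config = MyTrainingConfig()
config.from_dict({"batch_size": 16})  # override selected fields from a dict
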
@@ -0,0 +1,147 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from coqpit import Coqpit
from torch import nn

from TTS.utils.audio import AudioProcessor

# pylint: skip-file


class BaseModel(nn.Module, ABC):
    """Abstract 🐸TTS class. Every new 🐸TTS model must inherit this.

    Notes on input/output tensor shapes:
        Any input or output tensor of the model must be shaped as

        - 3D tensors `batch x time x channels`
        - 2D tensors `batch x channels`
        - 1D tensors `batch x 1`
    """

    @abstractmethod
    def forward(self, text: torch.Tensor, aux_input={}, **kwargs) -> Dict:
        """Forward pass for the model mainly used in training.

        You can be flexible here and use a different number of arguments and argument names since it is mostly
        used by `train_step()` in training without exposing it outside of the class.

        Args:
            text (torch.Tensor): Input text character sequence ids.
            aux_input (Dict): Auxiliary model inputs like embeddings, durations or any other sorts of inputs
                for the model.

        Returns:
            Dict: model outputs. This must include an item keyed `model_outputs` as the final artifact of the model.
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def inference(self, text: torch.Tensor, aux_input={}) -> Dict:
        """Forward pass for inference.

        After the model is trained, this is the only function that connects the model to the outside world.

        This function must only take a `text` input and a dictionary that has all the other model specific inputs.
        We don't use `**kwargs` since it is problematic with the TorchScript API.

        Args:
            text (torch.Tensor): [description]
            aux_input (Dict): Auxiliary inputs like speaker embeddings, durations etc.

        Returns:
            Dict: [description]
        """
        outputs_dict = {"model_outputs": None}
        ...
        return outputs_dict

    @abstractmethod
    def train_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single training step. Run the model forward pass and compute losses.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this returns from the criterion
        ...
        return outputs_dict, loss_dict

    def train_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """Create visualizations and waveform examples for training.

        For example, here you can plot spectrograms and generate sample waveforms from these spectrograms to
        be projected onto Tensorboard.

        Args:
            ap (AudioProcessor): audio processor used at training.
            batch (Dict): Model inputs used at the previous training step.
            outputs (Dict): Model outputs generated at the previous training step.

        Returns:
            Tuple[Dict, np.ndarray]: training plots and output waveform.
        """
        return None, None

    @abstractmethod
    def eval_step(self, batch: Dict, criterion: nn.Module) -> Tuple[Dict, Dict]:
        """Perform a single evaluation step. Run the model forward pass and compute losses. In most cases, you can
        call `train_step()` with no changes.

        Args:
            batch (Dict): Input tensors.
            criterion (nn.Module): Loss layer designed for the model.

        Returns:
            Tuple[Dict, Dict]: Model outputs and computed losses.
        """
        outputs_dict = {}
        loss_dict = {}  # this returns from the criterion
        ...
        return outputs_dict, loss_dict

    def eval_log(self, ap: AudioProcessor, batch: Dict, outputs: Dict) -> Tuple[Dict, np.ndarray]:
        """The same as `train_log()`."""
        return None, None

    @abstractmethod
    def load_checkpoint(self, config: Coqpit, checkpoint_path: str, eval: bool = False) -> None:
        """Load a checkpoint and get ready for training or inference.

        Args:
            config (Coqpit): Model configuration.
            checkpoint_path (str): Path to the model checkpoint file.
            eval (bool, optional): If true, init model for inference else for training. Defaults to False.
        """
        ...

    def get_optimizer(self) -> Union["Optimizer", List["Optimizer"]]:
        """Set up and return the optimizer or optimizers."""
        pass

    def get_lr(self) -> Union[float, List[float]]:
        """Return learning rate(s).

        Returns:
            Union[float, List[float]]: Model's initial learning rates.
        """
        pass

    def get_scheduler(self, optimizer: torch.optim.Optimizer):
        pass

    def get_criterion(self):
        pass

    def format_batch(self):
        pass
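To make the contract concrete, a toy subclass might look like the following (continuing with the imports above; the embedding layer and loss wiring are placeholders, not a real TTS model):

class ToyModel(BaseModel):
    def __init__(self, num_chars=100, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(num_chars, hidden)

    def forward(self, text, aux_input={}, **kwargs):
        return {"model_outputs": self.embed(text)}

    def inference(self, text, aux_input={}):
        return self.forward(text)

    def train_step(self, batch, criterion):
        outputs = self.forward(batch["text"])
        loss_dict = {"loss": criterion(outputs["model_outputs"], batch["target"])}
        return outputs, loss_dict

    def eval_step(self, batch, criterion):
        return self.train_step(batch, criterion)

    def load_checkpoint(self, config, checkpoint_path, eval=False):
        self.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"])
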
File diff suppressed because it is too large.
@@ -2,6 +2,7 @@ from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.align_tts import AlignTTSArgs


@dataclass

@@ -49,9 +50,9 @@ class AlignTTSConfig(BaseTTSConfig):
        use_speaker_embedding (bool):
            enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
            in the multi-speaker mode. Defaults to False.
        use_external_speaker_embedding_file (bool):
        use_d_vector_file (bool):
            enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
        noam_schedule (bool):
            enable / disable the use of Noam LR scheduler. Defaults to False.

@@ -68,17 +69,7 @@ class AlignTTSConfig(BaseTTSConfig):

    model: str = "align_tts"
    # model specific params
    positional_encoding: bool = True
    hidden_channels_dp: int = 256
    hidden_channels: int = 256
    encoder_type: str = "fftransformer"
    encoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
    )
    decoder_type: str = "fftransformer"
    decoder_params: dict = field(
        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
    )
    model_args: AlignTTSArgs = field(default_factory=AlignTTSArgs)
    phase_start_steps: List[int] = None

    ssim_alpha: float = 1.0
@@ -88,17 +79,29 @@ class AlignTTSConfig(BaseTTSConfig):

    # multi-speaker settings
    use_speaker_embedding: bool = False
    use_external_speaker_embedding_file: bool = False
    external_speaker_embedding_file: str = False
    use_d_vector_file: bool = False
    d_vector_file: str = False

    # optimizer parameters
    noam_schedule: bool = False
    warmup_steps: int = 4000
    optimizer: str = "Adam"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
    lr_scheduler: str = None
    lr_scheduler_params: dict = None
    lr: float = 1e-4
    wd: float = 1e-6
    grad_clip: float = 5.0

    # overrides
    min_seq_len: int = 13
    max_seq_len: int = 200
    r: int = 1

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )

@@ -7,7 +7,7 @@ from TTS.tts.configs.shared_configs import BaseTTSConfig
class GlowTTSConfig(BaseTTSConfig):
    """Defines parameters for GlowTTS model.

    Example:

        >>> from TTS.tts.configs import GlowTTSConfig
        >>> config = GlowTTSConfig()
@@ -23,13 +23,49 @@ class GlowTTSConfig(BaseTTSConfig):
            Defaults to `{"kernel_size": 3, "dropout_p": 0.1, "num_layers": 6, "num_heads": 2, "hidden_channels_ffn": 768}`
        use_encoder_prenet (bool):
            enable / disable the use of a prenet for the encoder. Defaults to True.
        hidden_channels_encoder (int):
        hidden_channels_enc (int):
            Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes,
            and for some encoder types internal hidden channels sizes too. Defaults to 192.
        hidden_channels_decoder (int):
        hidden_channels_dec (int):
            Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work.
        hidden_channels_duration_predictor (int):
        hidden_channels_dp (int):
            Number of layer channels of the duration predictor network. Defaults to 256 as in the original work.
        mean_only (bool):
            If true, predict only the mean values by the decoder flow. Defaults to True.
        out_channels (int):
            Number of channels of the model output tensor. Defaults to 80.
        num_flow_blocks_dec (int):
            Number of decoder blocks. Defaults to 12.
        inference_noise_scale (float):
            Noise scale used at inference. Defaults to 0.33.
        kernel_size_dec (int):
            Decoder kernel size. Defaults to 5.
        dilation_rate (int):
            Rate to increase dilation by each layer in a decoder block. Defaults to 1.
        num_block_layers (int):
            Number of decoder layers in each decoder block. Defaults to 4.
        dropout_p_dec (float):
            Dropout rate for the decoder. Defaults to 0.1.
        num_speakers (int):
            Number of speakers defining the size of the speaker embedding layer. Defaults to 0.
        c_in_channels (int):
            Number of speaker embedding channels. It is set to 512 if embeddings are learned. Defaults to 0.
        num_splits (int):
            Number of split levels in the invertible conv1x1 operation. Defaults to 4.
        num_squeeze (int):
            Number of squeeze levels. When squeezing, channels increase and time steps reduce by the factor
            `num_squeeze`. Defaults to 2.
        sigmoid_scale (bool):
            enable / disable sigmoid scaling in the decoder. Defaults to False.
        mean_only (bool):
            If True, the encoder only computes the mean value and uses a constant variance for each time step. Defaults to True.
        encoder_type (str):
            Encoder module type. Possible values are `["rel_pos_transformer", "gated_conv", "residual_conv_bn", "time_depth_separable"]`.
            Check `TTS.tts.layers.glow_tts.encoder` for more details. Defaults to `rel_pos_transformer` as in the original paper.
        encoder_params (dict):
            Encoder module parameters. Defaults to None.
        d_vector_dim (int):
            Channels of external speaker embedding vectors. Defaults to 0.
        data_dep_init_steps (int):
            Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
            Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
@@ -38,12 +74,14 @@ class GlowTTSConfig(BaseTTSConfig):
            Path to the wav file used for changing the style of the speech. Defaults to None.
        inference_noise_scale (float):
            Variance used for sampling the random noise added to the decoder's input at inference. Defaults to 0.0.
        length_scale (float):
            Multiply the predicted durations with this value to change the speech speed. Defaults to 1.
        use_speaker_embedding (bool):
            enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
            in the multi-speaker mode. Defaults to False.
        use_external_speaker_embedding_file (bool):
        use_d_vector_file (bool):
            enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
        noam_schedule (bool):
            enable / disable the use of Noam LR scheduler. Defaults to False.
@@ -62,6 +100,7 @@ class GlowTTSConfig(BaseTTSConfig):
    model: str = "glow_tts"

    # model params
    num_chars: int = None
    encoder_type: str = "rel_pos_transformer"
    encoder_params: dict = field(
        default_factory=lambda: {
@@ -73,9 +112,35 @@ class GlowTTSConfig(BaseTTSConfig):
        }
    )
    use_encoder_prenet: bool = True
    hidden_channels_encoder: int = 192
    hidden_channels_decoder: int = 192
    hidden_channels_duration_predictor: int = 256
    hidden_channels_enc: int = 192
    hidden_channels_dec: int = 192
    hidden_channels_dp: int = 256
    dropout_p_dp: float = 0.1
    dropout_p_dec: float = 0.05
    mean_only: bool = True
    out_channels: int = 80
    num_flow_blocks_dec: int = 12
    inference_noise_scale: float = 0.33
    kernel_size_dec: int = 5
    dilation_rate: int = 1
    num_block_layers: int = 4
    num_speakers: int = 0
    c_in_channels: int = 0
    num_splits: int = 4
    num_squeeze: int = 2
    sigmoid_scale: bool = False
    encoder_type: str = "rel_pos_transformer"
    encoder_params: dict = field(
        default_factory=lambda: {
            "kernel_size": 3,
            "dropout_p": 0.1,
            "num_layers": 6,
            "num_heads": 2,
            "hidden_channels_ffn": 768,
            "input_length": None,
        }
    )
    d_vector_dim: int = 0

    # training params
    data_dep_init_steps: int = 10
@@ -83,18 +148,20 @@ class GlowTTSConfig(BaseTTSConfig):
    # inference params
    style_wav_for_test: str = None
    inference_noise_scale: float = 0.0
    length_scale: float = 1.0

    # multi-speaker settings
    use_speaker_embedding: bool = False
    use_external_speaker_embedding_file: bool = False
    external_speaker_embedding_file: str = False
    use_d_vector_file: bool = False
    d_vector_file: str = False

    # optimizer params
    noam_schedule: bool = True
    warmup_steps: int = 4000
    # optimizer parameters
    optimizer: str = "RAdam"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
    lr_scheduler: str = "NoamLR"
    lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
    grad_clip: float = 5.0
    lr: float = 1e-3
    wd: float = 0.000001

    # overrides
    min_seq_len: int = 3

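Several configs pair `lr_scheduler: "NoamLR"` with `{"warmup_steps": 4000}`. The Noam rule ramps the LR linearly over the warm-up and then decays it with the inverse square root of the step; as a standalone function (a sketch of the rule, not the library class):

def noam_lr(step, base_lr=1e-3, warmup_steps=4000):
    """Peak at `base_lr` when step == warmup_steps, inverse-sqrt decay afterwards."""
    step = max(step, 1)
    return base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
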
@@ -1,7 +1,7 @@
from dataclasses import asdict, dataclass, field
from typing import List

from coqpit import MISSING, Coqpit, check_argument
from coqpit import Coqpit, check_argument

from TTS.config import BaseAudioConfig, BaseDatasetConfig, BaseTrainingConfig

@@ -133,6 +133,18 @@ class BaseTTSConfig(BaseTrainingConfig):
        datasets (List[BaseDatasetConfig]):
            List of datasets used for training. If multiple datasets are provided, they are merged and used together
            for training.
        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to ``.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`.
        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to ``.
        lr_scheduler_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
        test_sentences (List[str]):
            List of sentences to be used at testing. Defaults to '[]'.
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)
@@ -141,7 +153,7 @@ class BaseTTSConfig(BaseTrainingConfig):
    use_espeak_phonemes: bool = True
    phoneme_language: str = None
    compute_input_seq_cache: bool = False
    text_cleaner: str = MISSING
    text_cleaner: str = None
    enable_eos_bos_chars: bool = False
    test_sentences_file: str = ""
    phoneme_cache_path: str = None
@@ -158,3 +170,15 @@ class BaseTTSConfig(BaseTrainingConfig):
    add_blank: bool = False
    # dataset
    datasets: List[BaseDatasetConfig] = field(default_factory=lambda: [BaseDatasetConfig()])
    # optimizer
    optimizer: str = None
    optimizer_params: dict = None
    # scheduler
    lr_scheduler: str = ""
    lr_scheduler_params: dict = field(default_factory=lambda: {})
    # testing
    test_sentences: List[str] = field(default_factory=lambda: [])
    # multi-speaker
    use_speaker_embedding: bool = False
    use_d_vector_file: bool = False
    d_vector_dim: int = 0

@@ -1,6 +1,8 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.speedy_speech import SpeedySpeechArgs


@dataclass
@@ -15,30 +17,8 @@ class SpeedySpeechConfig(BaseTTSConfig):
    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `speedy_speech`.
        positional_encoding (bool):
            enable / disable positional encoding applied to the encoder output. Defaults to True.
        hidden_channels (int):
            Base number of hidden channels. Defines all the layers except the ones defined by the specific encoder or decoder
            parameters. Defaults to 128.
        encoder_type (str):
            Type of the encoder used by the model. Look at `TTS.tts.layers.feed_forward.encoder` for more details.
            Defaults to `residual_conv_bn`.
        encoder_params (dict):
            Parameters used to define the encoder network. Look at `TTS.tts.layers.feed_forward.encoder` for more details.
            Defaults to `{"kernel_size": 4, "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1], "num_conv_blocks": 2, "num_res_blocks": 13}`
        decoder_type (str):
            Type of the decoder used by the model. Look at `TTS.tts.layers.feed_forward.decoder` for more details.
            Defaults to `residual_conv_bn`.
        decoder_params (dict):
            Parameters used to define the decoder network. Look at `TTS.tts.layers.feed_forward.decoder` for more details.
            Defaults to `{"kernel_size": 4, "dilations": [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1], "num_conv_blocks": 2, "num_res_blocks": 17}`
        hidden_channels_encoder (int):
            Number of base hidden channels used by the encoder network. It defines the input and the output channel sizes,
            and for some encoder types internal hidden channels sizes too. Defaults to 192.
        hidden_channels_decoder (int):
            Number of base hidden channels used by the decoder WaveNet network. Defaults to 192 as in the original work.
        hidden_channels_duration_predictor (int):
            Number of layer channels of the duration predictor network. Defaults to 256 as in the original work.
        model_args (Coqpit):
            Model class arguments. Check `SpeedySpeechArgs` for more details. Defaults to `SpeedySpeechArgs()`.
        data_dep_init_steps (int):
            Number of steps used for computing normalization parameters at the beginning of the training. GlowTTS uses
            Activation Normalization that pre-computes normalization stats at the beginning and uses the same values
@@ -46,9 +26,9 @@ class SpeedySpeechConfig(BaseTTSConfig):
        use_speaker_embedding (bool):
            enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
            in the multi-speaker mode. Defaults to False.
        use_external_speaker_embedding_file (bool):
        use_d_vector_file (bool):
            enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
        noam_schedule (bool):
            enable / disable the use of Noam LR scheduler. Defaults to False.
@@ -72,37 +52,19 @@ class SpeedySpeechConfig(BaseTTSConfig):

    model: str = "speedy_speech"
    # model specific params
    positional_encoding: bool = True
    hidden_channels: int = 128
    encoder_type: str = "residual_conv_bn"
    encoder_params: dict = field(
        default_factory=lambda: {
            "kernel_size": 4,
            "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
            "num_conv_blocks": 2,
            "num_res_blocks": 13,
        }
    )
    decoder_type: str = "residual_conv_bn"
    decoder_params: dict = field(
        default_factory=lambda: {
            "kernel_size": 4,
            "dilations": [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1],
            "num_conv_blocks": 2,
            "num_res_blocks": 17,
        }
    )
    model_args: SpeedySpeechArgs = field(default_factory=SpeedySpeechArgs)

    # multi-speaker settings
    use_speaker_embedding: bool = False
    use_external_speaker_embedding_file: bool = False
    external_speaker_embedding_file: str = False
    use_d_vector_file: bool = False
    d_vector_file: str = False

    # optimizer parameters
    noam_schedule: bool = False
    warmup_steps: int = 4000
    optimizer: str = "RAdam"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
    lr_scheduler: str = None
    lr_scheduler_params: dict = None
    lr: float = 1e-4
    wd: float = 1e-6
    grad_clip: float = 5.0

    # loss params
@@ -114,3 +76,14 @@ class SpeedySpeechConfig(BaseTTSConfig):
    min_seq_len: int = 13
    max_seq_len: int = 200
    r: int = 1  # DO NOT CHANGE

    # testing
    test_sentences: List[str] = field(
        default_factory=lambda: [
            "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
            "Be a voice, not an echo.",
            "I'm sorry Dave. I'm afraid I can't do that.",
            "This cake is great. It's so delicious and moist.",
            "Prior to November 22, 1963.",
        ]
    )

@@ -12,107 +12,10 @@ class Tacotron2Config(TacotronConfig):
        >>> from TTS.tts.configs import Tacotron2Config
        >>> config = Tacotron2Config()

    Args:
        model (str):
            Model name used to select the right model class to initialize. Defaults to `Tacotron2`.
        use_gst (bool):
            enable / disable the use of Global Style Token modules. Defaults to False.
        gst (GSTConfig):
            Instance of `GSTConfig` class.
        gst_style_input (str):
            Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and
            this is not defined, the model uses a zero vector as an input. Defaults to None.
        r (int):
            Number of output frames that the decoder computes per iteration. Larger values make training and inference
            faster but reduce the quality of the output frames. This needs to be tuned considering your own needs.
            Defaults to 1.
        gradual_training (List[List]):
            Parameters for the gradual training schedule. It is in the form `[[a, b, c], [d, e, f], ...]` where `a` is
            the step number to start using the rest of the values, `b` is the `r` value and `c` is the batch size.
            If set to None, no gradual training is used. Defaults to None.
        memory_size (int):
            Defines the number of previous frames used by the Prenet. If set to < 0, then it uses only the last frame.
            Defaults to -1.
        prenet_type (str):
            `original` or `bn`. `original` sets the default Prenet and `bn` uses Batch Normalization version of the
            Prenet. Defaults to `original`.
        prenet_dropout (bool):
            enables / disables the use of dropout in the Prenet. Defaults to True.
        prenet_dropout_at_inference (bool):
            enable / disable the use of dropout in the Prenet at the inference time. Defaults to False.
        stopnet (bool):
            enable / disable the Stopnet that predicts the end of the decoder sequence. Defaults to True.
        stopnet_pos_weight (float):
            Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
            datasets with longer sentences. Defaults to 10.
        separate_stopnet (bool):
            Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True.
        attention_type (str):
            attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
        attention_heads (int):
            Number of attention heads for GMM attention. Defaults to 5.
        windowing (bool):
            It is especially useful at inference to keep attention alignment diagonal. Defaults to False.
        use_forward_attn (bool):
            It is only valid if ```attn_type``` is ```original```. Defaults to False.
        forward_attn_mask (bool):
            enable / disable extra masking over forward attention. It is useful at inference to prevent
            possible attention failures. Defaults to False.
        transition_agent (bool):
            enable / disable transition agent in forward attention. Defaults to False.
        location_attn (bool):
            enable / disable location sensitive attention as in the original Tacotron2 paper.
            It is only valid if ```attn_type``` is ```original```. Defaults to True.
        bidirectional_decoder (bool):
            enable / disable bidirectional decoding. Defaults to False.
        double_decoder_consistency (bool):
            enable / disable double decoder consistency. Defaults to False.
        ddc_r (int):
            reduction rate used by the coarse decoder when `double_decoder_consistency` is in use. Set this
            as a multiple of the `r` value. Defaults to 6.
        use_speaker_embedding (bool):
            enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
            in the multi-speaker mode. Defaults to False.
        use_external_speaker_embedding_file (bool):
            enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
        noam_schedule (bool):
            enable / disable the use of Noam LR scheduler. Defaults to False.
        warmup_steps (int):
            Number of warm-up steps for the Noam scheduler. Defaults to 4000.
        lr (float):
            Initial learning rate. Defaults to `1e-4`.
        wd (float):
            Weight decay coefficient. Defaults to `1e-6`.
        grad_clip (float):
            Gradient clipping threshold. Defaults to `5`.
        seq_len_norm (bool):
            enable / disable the sequence length normalization in the loss functions. If set True, the loss of a sample
            is divided by the sequence length. Defaults to False.
        loss_masking (bool):
            enable / disable masking the paddings of the samples in loss computation. Defaults to True.
        decoder_loss_alpha (float):
            Weight for the decoder loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        postnet_loss_alpha (float):
            Weight for the postnet loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        postnet_diff_spec_alpha (float):
            Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        decoder_diff_spec_alpha (float):
            Weight for the decoder differential loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        decoder_ssim_alpha (float):
            Weight for the decoder SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        postnet_ssim_alpha (float):
            Weight for the postnet SSIM loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25.
        ga_alpha (float):
            Weight for the guided attention loss. If set less than or equal to zero, it disables the corresponding loss
            function. Defaults to 5.
    Check `TacotronConfig` for argument descriptions.
    """

    model: str = "tacotron2"
    out_channels: int = 80
    encoder_in_features: int = 512
    decoder_in_features: int = 512

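The gradual training schedule is a step-indexed lookup: as `global_step` passes each entry's first value, the trainer switches to that entry's `r` and batch size. A sketch of how such a schedule is consumed (the schedule values are illustrative):

gradual_training = [[0, 7, 64], [10000, 5, 64], [50000, 3, 32], [130000, 2, 32]]

def lookup_schedule(global_step, schedule):
    """Return (r, batch_size) from the last entry whose start step has been reached."""
    r, batch_size = schedule[0][1], schedule[0][2]
    for start_step, new_r, new_bs in schedule:
        if global_step >= start_step:
            r, batch_size = new_r, new_bs
    return r, batch_size
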
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig, GSTConfig
@@ -23,6 +23,10 @@ class TacotronConfig(BaseTTSConfig):
        gst_style_input (str):
            Path to the wav file used at inference to set the speech style through GST. If `GST` is enabled and
            this is not defined, the model uses a zero vector as an input. Defaults to None.
        num_chars (int):
            Number of characters used by the model. It must be defined before initializing the model. Defaults to None.
        num_speakers (int):
            Number of speakers for multi-speaker models. Defaults to 1.
        r (int):
            Initial number of output frames that the decoder computes per iteration. Larger values make training and inference
            faster but reduce the quality of the output frames. This must be equal to the largest `r` value used in
@@ -46,6 +50,14 @@ class TacotronConfig(BaseTTSConfig):
        stopnet_pos_weight (float):
            Weight that is applied to over-weight positive instances in the Stopnet loss. Use larger values with
            datasets with longer sentences. Defaults to 10.
        max_decoder_steps (int):
            Max number of steps allowed for the decoder. Defaults to 50.
        encoder_in_features (int):
            Channels of encoder input and character embedding tensors. Defaults to 256.
        decoder_in_features (int):
            Channels of decoder input and encoder output tensors. Defaults to 256.
        out_channels (int):
            Channels of the final model output. It must match the spectrogram size. Defaults to 80.
        separate_stopnet (bool):
            Use a distinct Stopnet which is trained separately from the rest of the model. Defaults to True.
        attention_type (str):
@@ -74,14 +86,20 @@ class TacotronConfig(BaseTTSConfig):
        use_speaker_embedding (bool):
            enable / disable using speaker embeddings for multi-speaker models. If set True, the model is
            in the multi-speaker mode. Defaults to False.
        use_external_speaker_embedding_file (bool):
        use_d_vector_file (bool):
            enable / disable using external speaker embeddings in place of the learned embeddings. Defaults to False.
        external_speaker_embedding_file (str):
        d_vector_file (str):
            Path to the file including pre-computed speaker embeddings. Defaults to None.
        noam_schedule (bool):
            enable / disable the use of Noam LR scheduler. Defaults to False.
        warmup_steps (int):
            Number of warm-up steps for the Noam scheduler. Defaults to 4000.
        optimizer (str):
            Optimizer used for the training. Set one from `torch.optim.Optimizer` or `TTS.utils.training`.
            Defaults to `RAdam`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`.
        lr_scheduler (str):
            Learning rate scheduler for the training. Use one from `torch.optim.Scheduler` schedulers or
            `TTS.utils.training`. Defaults to `NoamLR`.
        lr_scheduler_params (dict):
            Parameters for the generator learning rate scheduler. Defaults to `{"warmup": 4000}`.
        lr (float):
            Initial learning rate. Defaults to `1e-4`.
        wd (float):
@@ -103,6 +121,7 @@ class TacotronConfig(BaseTTSConfig):
            Weight for the postnet differential loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25
        decoder_diff_spec_alpha (float):

            Weight for the decoder differential loss of the Tacotron model. If set less than or equal to zero, it disables the
            corresponding loss function. Defaults to 0.25
        decoder_ssim_alpha (float):
@@ -117,10 +136,14 @@ class TacotronConfig(BaseTTSConfig):
    """

    model: str = "tacotron"
    # model_params: TacotronArgs = field(default_factory=lambda: TacotronArgs())
    use_gst: bool = False
    gst: GSTConfig = None
    gst_style_input: str = None

    # model specific params
    num_speakers: int = 1
    num_chars: int = 0
    r: int = 2
    gradual_training: List[List[int]] = None
    memory_size: int = -1

@@ -130,11 +153,17 @@ class TacotronConfig(BaseTTSConfig):
    stopnet: bool = True
    separate_stopnet: bool = True
    stopnet_pos_weight: float = 10.0
    max_decoder_steps: int = 500
    encoder_in_features: int = 256
    decoder_in_features: int = 256
    decoder_output_dim: int = 80
    out_channels: int = 513

    # attention layers
    attention_type: str = "original"
    attention_heads: int = None
    attention_norm: str = "sigmoid"
    attention_win: bool = False
    windowing: bool = False
    use_forward_attn: bool = False
    forward_attn_mask: bool = False
@ -148,14 +177,17 @@ class TacotronConfig(BaseTTSConfig):
|
|||
|
||||
# multi-speaker settings
|
||||
use_speaker_embedding: bool = False
|
||||
use_external_speaker_embedding_file: bool = False
|
||||
external_speaker_embedding_file: str = False
|
||||
speaker_embedding_dim: int = 512
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_file: str = False
|
||||
d_vector_dim: int = None
|
||||
|
||||
# optimizer parameters
|
||||
noam_schedule: bool = False
|
||||
warmup_steps: int = 4000
|
||||
optimizer: str = "RAdam"
|
||||
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.9, 0.998], "weight_decay": 1e-6})
|
||||
lr_scheduler: str = "NoamLR"
|
||||
lr_scheduler_params: dict = field(default_factory=lambda: {"warmup_steps": 4000})
|
||||
lr: float = 1e-4
|
||||
wd: float = 1e-6
|
||||
grad_clip: float = 5.0
|
||||
seq_len_norm: bool = False
|
||||
loss_masking: bool = True
|
||||
|
@ -169,8 +201,25 @@ class TacotronConfig(BaseTTSConfig):
|
|||
postnet_ssim_alpha: float = 0.25
|
||||
ga_alpha: float = 5.0
|
||||
|
||||
# testing
|
||||
test_sentences: List[str] = field(
|
||||
default_factory=lambda: [
|
||||
"It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
|
||||
"Be a voice, not an echo.",
|
||||
"I'm sorry Dave. I'm afraid I can't do that.",
|
||||
"This cake is great. It's so delicious and moist.",
|
||||
"Prior to November 22, 1963.",
|
||||
]
|
||||
)
|
||||
|
||||
def check_values(self):
|
||||
if self.gradual_training:
|
||||
assert (
|
||||
self.gradual_training[0][1] == self.r
|
||||
), f"[!] the first scheduled gradual training `r` must be equal to the model's `r` value. {self.gradual_training[0][1]} vs {self.r}"
|
||||
if self.model == "tacotron" and self.audio is not None:
|
||||
assert self.out_channels == (
|
||||
self.audio.fft_size // 2 + 1
|
||||
), f"{self.out_channels} vs {self.audio.fft_size // 2 + 1}"
|
||||
if self.model == "tacotron2" and self.audio is not None:
|
||||
assert self.out_channels == self.audio.num_mels
|
||||
|
|
|
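A quick sketch of the constraints the new `check_values` enforces, with illustrative values only; the `[start_step, r, batch_size]` layout of `gradual_training` entries is an assumption here, not something stated in this diff:

# hypothetical config values, for illustration only
r = 2
gradual_training = [[0, 2, 32], [10000, 1, 32]]  # assumed [start_step, r, batch_size] triples
assert gradual_training[0][1] == r  # the first scheduled `r` must match the model's `r`

fft_size, num_mels = 1024, 80
out_channels_tacotron = fft_size // 2 + 1  # 513 linear-spectrogram bins for `tacotron`
out_channels_tacotron2 = num_mels          # 80 mel bands for `tacotron2`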
@ -2,6 +2,7 @@ import collections
|
|||
import os
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -10,49 +11,82 @@ from torch.utils.data import Dataset
|
|||
|
||||
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
|
||||
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
outputs_per_step,
|
||||
text_cleaner,
|
||||
compute_linear_spec,
|
||||
ap,
|
||||
meta_data,
|
||||
tp=None,
|
||||
add_blank=False,
|
||||
batch_group_size=0,
|
||||
min_seq_len=0,
|
||||
max_seq_len=float("inf"),
|
||||
use_phonemes=False,
|
||||
phoneme_cache_path=None,
|
||||
phoneme_language="en-us",
|
||||
enable_eos_bos=False,
|
||||
speaker_mapping=None,
|
||||
use_noise_augment=False,
|
||||
verbose=False,
|
||||
outputs_per_step: int,
|
||||
text_cleaner: list,
|
||||
compute_linear_spec: bool,
|
||||
ap: AudioProcessor,
|
||||
meta_data: List[List],
|
||||
characters: Dict = None,
|
||||
add_blank: bool = False,
|
||||
batch_group_size: int = 0,
|
||||
min_seq_len: int = 0,
|
||||
max_seq_len: int = float("inf"),
|
||||
use_phonemes: bool = False,
|
||||
phoneme_cache_path: str = None,
|
||||
phoneme_language: str = "en-us",
|
||||
enable_eos_bos: bool = False,
|
||||
speaker_id_mapping: Dict = None,
|
||||
d_vector_mapping: Dict = None,
|
||||
use_noise_augment: bool = False,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
"""Generic 📂 data loader for `tts` models. It is configurable for different outputs and needs.
|
||||
|
||||
If you need something different, you can either override or create a new class as the dataset is
|
||||
initialized by the model.
|
||||
|
||||
Args:
|
||||
outputs_per_step (int): number of time frames predicted per step.
|
||||
text_cleaner (str): text cleaner used for the dataset.
|
||||
outputs_per_step (int): Number of time frames predicted per step.
|
||||
|
||||
text_cleaner (list): List of text cleaners to clean the input text before converting to sequence IDs.
|
||||
|
||||
compute_linear_spec (bool): compute linear spectrogram if True.
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
meta_data (list): list of dataset instances.
|
||||
tp (dict): dict of custom text characters used for converting texts to sequences.
|
||||
batch_group_size (int): (0) range of batch randomization after sorting
|
||||
sequences by length.
|
||||
min_seq_len (int): (0) minimum sequence length to be processed
|
||||
by the loader.
|
||||
max_seq_len (int): (float("inf")) maximum sequence length.
|
||||
use_phonemes (bool): (true) if true, text converted to phonemes.
|
||||
phoneme_cache_path (str): path to cache phoneme features.
|
||||
phoneme_language (str): one the languages from
|
||||
https://github.com/bootphon/phonemizer#languages
|
||||
enable_eos_bos (bool): enable end of sentence and beginning of sentences characters.
|
||||
use_noise_augment (bool): enable adding random noise to wav for augmentation.
|
||||
verbose (bool): print diagnostic information.
|
||||
|
||||
ap (TTS.tts.utils.AudioProcessor): Audio processor object.
|
||||
|
||||
meta_data (list): List of dataset instances.
|
||||
|
||||
characters (dict): `dict` of custom text characters used for converting texts to sequences.
|
||||
|
||||
add_blank (bool): Add a special `blank` character after every other character. It helps some
|
||||
models achieve better results. Defaults to false.
|
||||
|
||||
batch_group_size (int): Range of batch randomization after sorting
|
||||
sequences by length. It shuffles each batch with bucketing to gather similar length sequences in a
|
||||
batch. Set 0 to disable. Defaults to 0.
|
||||
|
||||
min_seq_len (int): Minimum input sequence length to be processed
|
||||
by the loader. Filter out input sequences that are shorter than this. Some models have a
|
||||
minimum input length due to their architecture. Defaults to 0.
|
||||
|
||||
max_seq_len (int): Maximum input sequence length. Filter out input sequences that are longer than this.
|
||||
It helps to control VRAM usage for long input sequences. Models with
|
||||
RNN layers are especially sensitive to input length. Defaults to `Inf`.
|
||||
|
||||
use_phonemes (bool): If true, the input text is converted to phonemes. Defaults to false.
|
||||
|
||||
phoneme_cache_path (str): Path to cache phoneme features. It writes computed phonemes to files to use in
|
||||
the coming iterations. Defaults to None.
|
||||
|
||||
phoneme_language (str): One of the languages supported by the phonemizer interface. Defaults to `en-us`.
|
||||
|
||||
enable_eos_bos (bool): Enable the `end of sentence` and the `beginning of sentence` characters. Defaults
|
||||
to False.
|
||||
|
||||
speaker_id_mapping (dict): Mapping of speaker names to IDs used to compute embedding vectors by the
|
||||
embedding layer. Defaults to None.
|
||||
|
||||
d_vector_mapping (dict): Mapping of wav files to computed d-vectors. Defaults to None.
|
||||
|
||||
use_noise_augment (bool): Enable adding random noise to wav for augmentation. Defaults to False.
|
||||
|
||||
verbose (bool): Print diagnostic information. Defaults to false.
|
||||
"""
|
||||
super().__init__()
|
||||
self.batch_group_size = batch_group_size
|
||||
|
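A hedged sketch of how the renamed `TTSDataset` might be wired up, following the new signature above. Every argument value below is an assumption, and `AudioProcessor` would carry whatever audio settings your config requires:

from torch.utils.data import DataLoader

from TTS.tts.datasets.TTSDataset import TTSDataset
from TTS.utils.audio import AudioProcessor

ap = AudioProcessor(sample_rate=22050, num_mels=80)  # remaining audio settings assumed
meta_data = [["Hello world.", "wavs/0001.wav", "ljspeech"]]  # [text, wav_path, speaker_name]
dataset = TTSDataset(
    outputs_per_step=2,
    text_cleaner="phoneme_cleaners",  # cleaner name assumed
    compute_linear_spec=False,        # True for Tacotron's linear spectrograms
    ap=ap,
    meta_data=meta_data,
    min_seq_len=3,                    # drop inputs shorter than 3 tokens
    max_seq_len=500,                  # cap VRAM usage on long inputs
)
loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)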
@ -64,13 +98,14 @@ class MyDataset(Dataset):
|
|||
self.min_seq_len = min_seq_len
|
||||
self.max_seq_len = max_seq_len
|
||||
self.ap = ap
|
||||
self.tp = tp
|
||||
self.characters = characters
|
||||
self.add_blank = add_blank
|
||||
self.use_phonemes = use_phonemes
|
||||
self.phoneme_cache_path = phoneme_cache_path
|
||||
self.phoneme_language = phoneme_language
|
||||
self.enable_eos_bos = enable_eos_bos
|
||||
self.speaker_mapping = speaker_mapping
|
||||
self.speaker_id_mapping = speaker_id_mapping
|
||||
self.d_vector_mapping = d_vector_mapping
|
||||
self.use_noise_augment = use_noise_augment
|
||||
self.verbose = verbose
|
||||
self.input_seq_computed = False
|
||||
|
@ -93,13 +128,13 @@ class MyDataset(Dataset):
|
|||
return data
|
||||
|
||||
@staticmethod
|
||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
|
||||
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, characters, add_blank):
|
||||
"""generate a phoneme sequence from text.
|
||||
Since the usage is for subsequent caching, we never add bos and
|
||||
eos chars here. Instead, we add those dynamically later, based on the
|
||||
config option."""
|
||||
phonemes = phoneme_to_sequence(
|
||||
text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
|
||||
text, [cleaners], language=language, enable_eos_bos=False, tp=characters, add_blank=add_blank
|
||||
)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
np.save(cache_path, phonemes)
|
||||
|
@ -107,7 +142,7 @@ class MyDataset(Dataset):
|
|||
|
||||
@staticmethod
|
||||
def _load_or_generate_phoneme_sequence(
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
|
||||
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, characters, add_blank
|
||||
):
|
||||
file_name = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
|
||||
|
@ -117,16 +152,16 @@ class MyDataset(Dataset):
|
|||
try:
|
||||
phonemes = np.load(cache_path)
|
||||
except FileNotFoundError:
|
||||
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, tp, add_blank
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
)
|
||||
except (ValueError, IOError):
|
||||
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
|
||||
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, tp, add_blank
|
||||
phonemes = TTSDataset._generate_and_cache_phoneme_sequence(
|
||||
text, cache_path, cleaners, language, characters, add_blank
|
||||
)
|
||||
if enable_eos_bos:
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=tp)
|
||||
phonemes = pad_with_eos_bos(phonemes, tp=characters)
|
||||
phonemes = np.asarray(phonemes, dtype=np.int32)
|
||||
return phonemes
|
||||
|
||||
|
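The load-or-generate logic above keys the phoneme cache on the wav file's basename; a minimal standalone sketch of that pattern (the `_phoneme.npy` suffix and the callback shape are assumptions):

import os

import numpy as np

def load_or_compute_phonemes(wav_file, cache_dir, compute_fn):
    file_name = os.path.splitext(os.path.basename(wav_file))[0]
    cache_path = os.path.join(cache_dir, file_name + "_phoneme.npy")  # suffix assumed
    try:
        return np.load(cache_path)
    except (FileNotFoundError, ValueError, IOError):
        # missing or corrupt cache: recompute and overwrite
        phonemes = np.asarray(compute_fn(), dtype=np.int32)
        np.save(cache_path, phonemes)
        return phonemes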
@ -154,13 +189,14 @@ class MyDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.tp,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
)
|
||||
|
||||
else:
|
||||
text = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
dtype=np.int32,
|
||||
)
|
||||
|
||||
assert text.size > 0, self.items[idx][1]
|
||||
|
@ -190,7 +226,7 @@ class MyDataset(Dataset):
|
|||
item = args[0]
|
||||
func_args = args[1]
|
||||
text, wav_file, *_ = item
|
||||
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
|
||||
phonemes = TTSDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
|
||||
return phonemes
|
||||
|
||||
def compute_input_seq(self, num_workers=0):
|
||||
|
@ -202,7 +238,8 @@ class MyDataset(Dataset):
|
|||
for idx, item in enumerate(tqdm.tqdm(self.items)):
|
||||
text, *_ = item
|
||||
sequence = np.asarray(
|
||||
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
|
||||
text_to_sequence(text, [self.cleaners], tp=self.characters, add_blank=self.add_blank),
|
||||
dtype=np.int32,
|
||||
)
|
||||
self.items[idx][0] = sequence
|
||||
|
||||
|
@ -212,7 +249,7 @@ class MyDataset(Dataset):
|
|||
self.enable_eos_bos,
|
||||
self.cleaners,
|
||||
self.phoneme_language,
|
||||
self.tp,
|
||||
self.characters,
|
||||
self.add_blank,
|
||||
]
|
||||
if self.verbose:
|
||||
|
@ -225,7 +262,7 @@ class MyDataset(Dataset):
|
|||
with Pool(num_workers) as p:
|
||||
phonemes = list(
|
||||
tqdm.tqdm(
|
||||
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
|
||||
p.imap(TTSDataset._phoneme_worker, [[item, func_args] for item in self.items]),
|
||||
total=len(self.items),
|
||||
)
|
||||
)
|
||||
|
@ -282,7 +319,7 @@ class MyDataset(Dataset):
|
|||
"""
|
||||
|
||||
# Puts each data field into a tensor with outer dimension batch size
|
||||
if isinstance(batch[0], collections.Mapping):
|
||||
if isinstance(batch[0], collections.abc.Mapping):
|
||||
|
||||
text_lenghts = np.array([len(d["text"]) for d in batch])
|
||||
|
||||
|
@ -293,13 +330,18 @@ class MyDataset(Dataset):
|
|||
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
|
||||
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
|
||||
|
||||
speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
||||
# get speaker embeddings
|
||||
if self.speaker_mapping is not None:
|
||||
speaker_names = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
|
||||
# get pre-computed d-vectors
|
||||
if self.d_vector_mapping is not None:
|
||||
wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
|
||||
speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
|
||||
d_vectors = [self.d_vector_mapping[w]["embedding"] for w in wav_files_names]
|
||||
else:
|
||||
speaker_embedding = None
|
||||
d_vectors = None
|
||||
# get numerical speaker ids from speaker names
|
||||
if self.speaker_id_mapping:
|
||||
speaker_ids = [self.speaker_id_mapping[sn] for sn in speaker_names]
|
||||
else:
|
||||
speaker_ids = None
|
||||
# compute features
|
||||
mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
|
||||
|
||||
|
@ -327,8 +369,11 @@ class MyDataset(Dataset):
|
|||
mel_lengths = torch.LongTensor(mel_lengths)
|
||||
stop_targets = torch.FloatTensor(stop_targets)
|
||||
|
||||
if speaker_embedding is not None:
|
||||
speaker_embedding = torch.FloatTensor(speaker_embedding)
|
||||
if d_vectors is not None:
|
||||
d_vectors = torch.FloatTensor(d_vectors)
|
||||
|
||||
if speaker_ids is not None:
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
|
||||
# compute linear spectrogram
|
||||
if self.compute_linear_spec:
|
||||
|
@ -355,13 +400,14 @@ class MyDataset(Dataset):
|
|||
return (
|
||||
text,
|
||||
text_lenghts,
|
||||
speaker_name,
|
||||
speaker_names,
|
||||
linear,
|
||||
mel,
|
||||
mel_lengths,
|
||||
stop_targets,
|
||||
item_idxs,
|
||||
speaker_embedding,
|
||||
d_vectors,
|
||||
speaker_ids,
|
||||
attns,
|
||||
)
|
||||
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
import sys
|
||||
from collections import Counter
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from TTS.tts.datasets.formatters import *
|
||||
from TTS.tts.datasets.TTSDataset import TTSDataset
|
||||
|
||||
|
||||
def split_dataset(items):
|
||||
speakers = [item[-1] for item in items]
|
||||
is_multi_speaker = len(set(speakers)) > 1
|
||||
eval_split_size = min(500, int(len(items) * 0.01))
|
||||
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
|
||||
np.random.seed(0)
|
||||
np.random.shuffle(items)
|
||||
if is_multi_speaker:
|
||||
items_eval = []
|
||||
speakers = [item[-1] for item in items]
|
||||
speaker_counter = Counter(speakers)
|
||||
while len(items_eval) < eval_split_size:
|
||||
item_idx = np.random.randint(0, len(items))
|
||||
speaker_to_be_removed = items[item_idx][-1]
|
||||
if speaker_counter[speaker_to_be_removed] > 1:
|
||||
items_eval.append(items[item_idx])
|
||||
speaker_counter[speaker_to_be_removed] -= 1
|
||||
del items[item_idx]
|
||||
return items_eval, items
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
||||
|
||||
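A toy run of the single-speaker branch of `split_dataset`, so the size math is visible: the eval split is 1% of the items, capped at 500. In the multi-speaker branch the function instead pops random items while never draining a speaker below one remaining training sample. Values below are illustrative:

import numpy as np

items = [[f"text {i}", f"wavs/{i}.wav", "speaker_0"] for i in range(300)]
eval_split_size = min(500, int(len(items) * 0.01))  # -> 3
np.random.seed(0)
np.random.shuffle(items)
eval_items, train_items = items[:eval_split_size], items[eval_split_size:]
print(len(eval_items), len(train_items))  # 3 297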
def load_meta_data(datasets, eval_split=True, ignore_generated_eval=False):
|
||||
meta_data_train_all = []
|
||||
meta_data_eval_all = [] if eval_split else None
|
||||
for dataset in datasets:
|
||||
name = dataset["name"]
|
||||
root_path = dataset["path"]
|
||||
meta_file_train = dataset["meta_file_train"]
|
||||
meta_file_val = dataset["meta_file_val"]
|
||||
# setup the right data processor
|
||||
preprocessor = _get_preprocessor_by_name(name)
|
||||
# load train set
|
||||
meta_data_train = preprocessor(root_path, meta_file_train)
|
||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||
# load evaluation split if set
|
||||
if eval_split:
|
||||
if meta_file_val:
|
||||
meta_data_eval = preprocessor(root_path, meta_file_val)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
elif not ignore_generated_eval:
|
||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
|
||||
meta_data_train_all += meta_data_train
|
||||
# load attention masks for duration predictor training
|
||||
if dataset.meta_file_attn_mask:
|
||||
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||
for idx, ins in enumerate(meta_data_train_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
meta_data_train_all[idx].append(attn_file)
|
||||
if meta_data_eval_all:
|
||||
for idx, ins in enumerate(meta_data_eval_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
meta_data_eval_all[idx].append(attn_file)
|
||||
return meta_data_train_all, meta_data_eval_all
|
||||
|
||||
|
||||
def load_attention_mask_meta_data(metafile_path):
|
||||
"""Load meta data file created by compute_attention_masks.py"""
|
||||
with open(metafile_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
meta_data = []
|
||||
for line in lines:
|
||||
wav_file, attn_file = line.split("|")
|
||||
meta_data.append([wav_file, attn_file])
|
||||
return meta_data
|
||||
|
||||
|
||||
def _get_preprocessor_by_name(name):
|
||||
"""Returns the respective preprocessing function."""
|
||||
thismodule = sys.modules[__name__]
|
||||
return getattr(thismodule, name.lower())
|
|
@ -1,96 +1,12 @@
|
|||
import os
|
||||
import re
|
||||
import sys
|
||||
import xml.etree.ElementTree as ET
|
||||
from collections import Counter
|
||||
from glob import glob
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
####################
|
||||
# UTILITIES
|
||||
####################
|
||||
|
||||
|
||||
def split_dataset(items):
|
||||
speakers = [item[-1] for item in items]
|
||||
is_multi_speaker = len(set(speakers)) > 1
|
||||
eval_split_size = min(500, int(len(items) * 0.01))
|
||||
assert eval_split_size > 0, " [!] You do not have enough samples to train. You need at least 100 samples."
|
||||
np.random.seed(0)
|
||||
np.random.shuffle(items)
|
||||
if is_multi_speaker:
|
||||
items_eval = []
|
||||
speakers = [item[-1] for item in items]
|
||||
speaker_counter = Counter(speakers)
|
||||
while len(items_eval) < eval_split_size:
|
||||
item_idx = np.random.randint(0, len(items))
|
||||
speaker_to_be_removed = items[item_idx][-1]
|
||||
if speaker_counter[speaker_to_be_removed] > 1:
|
||||
items_eval.append(items[item_idx])
|
||||
speaker_counter[speaker_to_be_removed] -= 1
|
||||
del items[item_idx]
|
||||
return items_eval, items
|
||||
return items[:eval_split_size], items[eval_split_size:]
|
||||
|
||||
|
||||
def load_meta_data(datasets, eval_split=True, ignore_generated_eval=False):
|
||||
meta_data_train_all = []
|
||||
meta_data_eval_all = [] if eval_split else None
|
||||
for dataset in datasets:
|
||||
name = dataset["name"]
|
||||
root_path = dataset["path"]
|
||||
meta_file_train = dataset["meta_file_train"]
|
||||
meta_file_val = dataset["meta_file_val"]
|
||||
# setup the right data processor
|
||||
preprocessor = get_preprocessor_by_name(name)
|
||||
# load train set
|
||||
meta_data_train = preprocessor(root_path, meta_file_train)
|
||||
print(f" | > Found {len(meta_data_train)} files in {Path(root_path).resolve()}")
|
||||
# load evaluation split if set
|
||||
if eval_split:
|
||||
if meta_file_val:
|
||||
meta_data_eval = preprocessor(root_path, meta_file_val)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
elif not ignore_generated_eval:
|
||||
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
|
||||
meta_data_eval_all += meta_data_eval
|
||||
|
||||
meta_data_train_all += meta_data_train
|
||||
# load attention masks for duration predictor training
|
||||
if dataset.meta_file_attn_mask:
|
||||
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
|
||||
for idx, ins in enumerate(meta_data_train_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
meta_data_train_all[idx].append(attn_file)
|
||||
if meta_data_eval_all:
|
||||
for idx, ins in enumerate(meta_data_eval_all):
|
||||
attn_file = meta_data[ins[1]].strip()
|
||||
meta_data_eval_all[idx].append(attn_file)
|
||||
return meta_data_train_all, meta_data_eval_all
|
||||
|
||||
|
||||
def load_attention_mask_meta_data(metafile_path):
|
||||
"""Load meta data file created by compute_attention_masks.py"""
|
||||
with open(metafile_path, "r") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
meta_data = []
|
||||
for line in lines:
|
||||
wav_file, attn_file = line.split("|")
|
||||
meta_data.append([wav_file, attn_file])
|
||||
return meta_data
|
||||
|
||||
|
||||
def get_preprocessor_by_name(name):
|
||||
"""Returns the respective preprocessing function."""
|
||||
thismodule = sys.modules[__name__]
|
||||
return getattr(thismodule, name.lower())
|
||||
|
||||
|
||||
########################
|
||||
# DATASETS
|
||||
########################
|
||||
|
@ -191,6 +107,20 @@ def ljspeech(root_path, meta_file):
|
|||
return items
|
||||
|
||||
|
||||
def ljspeech_test(root_path, meta_file):
|
||||
"""Normalizes the LJSpeech meta data file for TTS testing
|
||||
https://keithito.com/LJ-Speech-Dataset/"""
|
||||
txt_file = os.path.join(root_path, meta_file)
|
||||
items = []
|
||||
with open(txt_file, "r", encoding="utf-8") as ttf:
|
||||
for idx, line in enumerate(ttf):
|
||||
cols = line.split("|")
|
||||
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
|
||||
text = cols[1]
|
||||
items.append([text, wav_file, f"ljspeech-{idx}"])
|
||||
return items
|
||||
|
||||
|
||||
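For reference, the new `ljspeech_test` formatter expects LJSpeech-style pipe-separated rows and emits `[text, wav_path, speaker_name]` items with a synthetic per-line speaker id. A one-line walk-through, with an illustrative transcript:

line = "LJ001-0001|Printing, in the only sense with which we are at present concerned."
cols = line.split("|")
item = [cols[1], "LJSpeech-1.1/wavs/" + cols[0] + ".wav", "ljspeech-0"]
# -> ["Printing, ...", "LJSpeech-1.1/wavs/LJ001-0001.wav", "ljspeech-0"]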
def sam_accenture(root_path, meta_file):
|
||||
"""Normalizes the sam-accenture meta data file to TTS format
|
||||
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
|
|
@ -0,0 +1,15 @@
|
|||
from TTS.tts.layers.losses import *
|
||||
|
||||
|
||||
def setup_loss(config):
|
||||
if config.model.lower() in ["tacotron", "tacotron2"]:
|
||||
model = TacotronLoss(config)
|
||||
elif config.model.lower() == "glow_tts":
|
||||
model = GlowTTSLoss()
|
||||
elif config.model.lower() == "speedy_speech":
|
||||
model = SpeedySpeechLoss(config)
|
||||
elif config.model.lower() == "align_tts":
|
||||
model = AlignTTSLoss(config)
|
||||
else:
|
||||
raise ValueError(f" [!] loss for model {config.model.lower()} cannot be found.")
|
||||
return model
|
|
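A minimal usage sketch for the new `setup_loss` factory. The bare config stand-in below is an assumption; the real call receives the full model config, and the Tacotron and speedy-speech losses read more fields from it than `model` alone:

class DummyConfig:  # stand-in for the full Coqpit model config (assumed)
    model = "glow_tts"

criterion = setup_loss(DummyConfig())  # -> GlowTTSLoss instance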
@ -12,7 +12,8 @@ def squeeze(x, x_mask=None, num_sqz=2):
|
|||
|
||||
Note:
|
||||
each 's' is a n-dimensional vector.
|
||||
[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]"""
|
||||
``[s1,s2,s3,s4,s5,s6] --> [[s1, s3, s5], [s2, s4, s6]]``
|
||||
"""
|
||||
b, c, t = x.size()
|
||||
|
||||
t = (t // num_sqz) * num_sqz
|
||||
|
@ -32,7 +33,8 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
|
|||
|
||||
Note:
|
||||
each 's' is a n-dimensional vector.
|
||||
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]"""
|
||||
``[[s1, s3, s5], [s2, s4, s6]] --> [s1, s2, s3, s4, s5, s6]``
|
||||
"""
|
||||
b, c, t = x.size()
|
||||
|
||||
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
|
||||
|
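A quick tensor-level illustration of the interleaving the squeeze/unsqueeze notes describe, mirroring the view/permute pattern Glow-TTS uses (frames s1..s6 become two half-rate channels):

import torch

b, c, t, num_sqz = 1, 1, 6, 2
x = torch.arange(1.0, 7.0).view(b, c, t)  # [[s1, s2, s3, s4, s5, s6]]
t = (t // num_sqz) * num_sqz              # trim to a multiple of num_sqz
x_sqz = x[:, :, :t].view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
print(x_sqz)  # tensor([[[1., 3., 5.], [2., 4., 6.]]]) i.e. [[s1, s3, s5], [s2, s4, s6]]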
@ -47,7 +49,10 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
|
|||
|
||||
class Decoder(nn.Module):
|
||||
"""Stack of Glow Decoder Modules.
|
||||
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
|
||||
|
||||
::
|
||||
|
||||
Squeeze -> ActNorm -> InvertibleConv1x1 -> AffineCoupling -> Unsqueeze
|
||||
|
||||
Args:
|
||||
in_channels (int): channels of input tensor.
|
||||
|
@ -106,6 +111,12 @@ class Decoder(nn.Module):
|
|||
)
|
||||
|
||||
def forward(self, x, x_mask, g=None, reverse=False):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1 ,T]`
|
||||
- g: :math:`[B, C]`
|
||||
"""
|
||||
if not reverse:
|
||||
flows = self.flows
|
||||
logdet_tot = 0
|
||||
|
|
|
@ -6,13 +6,16 @@ from ..generic.normalization import LayerNorm
|
|||
|
||||
class DurationPredictor(nn.Module):
|
||||
"""Glow-TTS duration prediction model.
|
||||
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
|
||||
|
||||
Args:
|
||||
in_channels ([type]): [description]
|
||||
hidden_channels ([type]): [description]
|
||||
kernel_size ([type]): [description]
|
||||
dropout_p ([type]): [description]
|
||||
::
|
||||
|
||||
[2 x (conv1d_kxk -> relu -> layer_norm -> dropout)] -> conv1d_1x1 -> durs
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channels of the input tensor.
|
||||
hidden_channels (int): Number of hidden channels of the network.
|
||||
kernel_size (int): Kernel size for the conv layers.
|
||||
dropout_p (float): Dropout rate used after each conv layer.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
|
||||
|
@ -34,11 +37,8 @@ class DurationPredictor(nn.Module):
|
|||
def forward(self, x, x_mask):
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
x_mask: [B, 1, T]
|
||||
|
||||
Returns:
|
||||
[type]: [description]
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
x = self.conv_1(x * x_mask)
|
||||
x = torch.relu(x)
|
||||
|
|
|
@ -9,19 +9,22 @@ from TTS.tts.layers.generic.time_depth_sep_conv import TimeDepthSeparableConvBlo
|
|||
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
|
||||
from TTS.tts.layers.glow_tts.glow import ResidualConv1dLayerNormBlock
|
||||
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Glow-TTS encoder module.
|
||||
|
||||
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
|
||||
|
|
||||
|-> proj_var
|
||||
|
|
||||
|-> concat -> duration_predictor
|
||||
↑
|
||||
speaker_embed
|
||||
::
|
||||
|
||||
embedding -> <prenet> -> encoder_module -> <postnet> --> proj_mean
|
||||
|
|
||||
|-> proj_var
|
||||
|
|
||||
|-> concat -> duration_predictor
|
||||
↑
|
||||
speaker_embed
|
||||
|
||||
Args:
|
||||
num_chars (int): number of characters.
|
||||
out_channels (int): number of output channels.
|
||||
|
@ -36,7 +39,8 @@ class Encoder(nn.Module):
|
|||
Shapes:
|
||||
- input: (B, T, C)
|
||||
|
||||
Notes:
|
||||
::
|
||||
|
||||
suggested encoder params...
|
||||
|
||||
for encoder_type == 'rel_pos_transformer'
|
||||
|
@ -139,9 +143,9 @@ class Encoder(nn.Module):
|
|||
def forward(self, x, x_lengths, g=None):
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
x_lengths: [B]
|
||||
g (optional): [B, 1, T]
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_lengths: :math:`[B]`
|
||||
- g (optional): :math:`[B, 1, T]`
|
||||
"""
|
||||
# embedding layer
|
||||
# [B ,T, D]
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
@ -8,21 +10,24 @@ from ..generic.normalization import LayerNorm
|
|||
|
||||
|
||||
class ResidualConv1dLayerNormBlock(nn.Module):
|
||||
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
|
||||
https://arxiv.org/pdf/1811.00002.pdf
|
||||
|
||||
::
|
||||
|
||||
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|
||||
|---------------> conv1d_1x1 -----------------------|
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
hidden_channels (int): number of inner layer channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
kernel_size (int): kernel size of conv1d filter.
|
||||
num_layers (int): number of blocks.
|
||||
dropout_p (float): dropout rate for each block.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
|
||||
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
|
||||
https://arxiv.org/pdf/1811.00002.pdf
|
||||
|
||||
x |-> conv1d -> layer_norm -> relu -> dropout -> + -> o
|
||||
|---------------> conv1d_1x1 -----------------------|
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
hidden_channels (int): number of inner layer channels.
|
||||
out_channels (int): number of output tensor channels.
|
||||
kernel_size (int): kernel size of conv1d filter.
|
||||
num_layers (int): number of blocks.
|
||||
dropout_p (float): dropout rate for each block.
|
||||
"""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
|
@ -49,6 +54,11 @@ class ResidualConv1dLayerNormBlock(nn.Module):
|
|||
self.proj.bias.data.zero_()
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
x_res = x
|
||||
for i in range(self.num_layers):
|
||||
x = self.conv_layers[i](x * x_mask)
|
||||
|
@ -81,7 +91,11 @@ class InvConvNear(nn.Module):
|
|||
self.no_jacobian = no_jacobian
|
||||
self.weight_inv = None
|
||||
|
||||
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||
if LooseVersion(torch.__version__) < LooseVersion("1.9"):
|
||||
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
|
||||
else:
|
||||
w_init = torch.linalg.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_(), "complete")[0]
|
||||
|
||||
if torch.det(w_init) < 0:
|
||||
w_init[:, 0] = -1 * w_init[:, 0]
|
||||
self.weight = nn.Parameter(w_init)
|
||||
|
@ -89,8 +103,8 @@ class InvConvNear(nn.Module):
|
|||
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: B x C x T
|
||||
x_mask: B x 1 x T
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
|
||||
b, c, t = x.size()
|
||||
|
@ -133,10 +147,12 @@ class CouplingBlock(nn.Module):
|
|||
"""Glow Affine Coupling block as in GlowTTS paper.
|
||||
https://arxiv.org/pdf/1811.00002.pdf
|
||||
|
||||
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
|
||||
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
|
||||
::
|
||||
|
||||
Args:
|
||||
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
|
||||
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
|
||||
|
||||
Args:
|
||||
in_channels (int): number of input tensor channels.
|
||||
hidden_channels (int): number of hidden channels.
|
||||
kernel_size (int): WaveNet filter kernel size.
|
||||
|
@ -146,8 +162,8 @@ class CouplingBlock(nn.Module):
|
|||
dropout_p (int): wavenet dropout rate.
|
||||
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
|
||||
|
||||
Note:
|
||||
It does not use conditional inputs differently from WaveGlow.
|
||||
Note:
|
||||
It does not use the conditional inputs differently from WaveGlow.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -187,9 +203,9 @@ class CouplingBlock(nn.Module):
|
|||
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: B x C x T
|
||||
x_mask: B x 1 x T
|
||||
g: B x C x 1
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
- g: :math:`[B, C, 1]`
|
||||
"""
|
||||
if x_mask is None:
|
||||
x_mask = 1
|
||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
|
||||
try:
|
||||
# TODO: fix pypi cython installation problem.
|
||||
|
|
|
@ -17,16 +17,18 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
|
||||
Note:
|
||||
Example with relative attention window size 2
|
||||
input = [a, b, c, d, e]
|
||||
rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
|
||||
|
||||
- input = [a, b, c, d, e]
|
||||
- rel_attn_embeddings = [e(t-2), e(t-1), e(t+1), e(t+2)]
|
||||
|
||||
So it learns 4 embedding vectors (in total 8) separately for key and value vectors.
|
||||
|
||||
Considering the input c
|
||||
e(t-2) corresponds to c -> a
|
||||
e(t-2) corresponds to c -> b
|
||||
e(t-2) corresponds to c -> d
|
||||
e(t-2) corresponds to c -> e
|
||||
|
||||
- e(t-2) corresponds to c -> a
|
||||
- e(t-1) corresponds to c -> b
|
||||
- e(t+1) corresponds to c -> d
|
||||
- e(t+2) corresponds to c -> e
|
||||
|
||||
These embeddings are shared among different time steps. So inputs a, b, d, and e also use
|
||||
the same embeddings.
|
||||
|
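A sketch of the indexing this note implies, with window size 2: each query-key offset `j - i` is clipped to the window and selects one shared embedding row. The embedding dimension of 8 is arbitrary, and implementations typically also keep a row for the zero offset, hence `2 * window + 1` rows here:

import torch

window, t = 2, 5
emb = torch.nn.Embedding(2 * window + 1, 8)  # one row per offset in [-2, 2]
rel = torch.arange(t).unsqueeze(0) - torch.arange(t).unsqueeze(1)  # rel[i, j] = j - i
idx = rel.clamp(-window, window) + window    # map offsets to rows 0..4
rel_emb = emb(idx)  # [T, T, 8]; shared across all (i, j) pairs with the same offset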
@ -106,6 +108,12 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
nn.init.xavier_uniform_(self.conv_v.weight)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
"""
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T]`
|
||||
- c: :math:`[B, C, T]`
|
||||
- attn_mask: :math:`[B, 1, T, T]`
|
||||
"""
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
|
@ -163,9 +171,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
re (Tensor): relative value embedding vector. (a_(i,j)^V)
|
||||
|
||||
Shapes:
|
||||
p_attn: [B, H, T, V]
|
||||
re: [H or 1, V, D]
|
||||
logits: [B, H, T, D]
|
||||
- p_attn: :math:`[B, H, T, V]`
|
||||
- re: :math:`[H or 1, V, D]`
|
||||
- logits: :math:`[B, H, T, D]`
|
||||
"""
|
||||
logits = torch.matmul(p_attn, re.unsqueeze(0))
|
||||
return logits
|
||||
|
@ -178,9 +186,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
re (Tensor): relative key embedding vector. (a_(i,j)^K)
|
||||
|
||||
Shapes:
|
||||
query: [B, H, T, D]
|
||||
re: [H or 1, V, D]
|
||||
logits: [B, H, T, V]
|
||||
- query: :math:`[B, H, T, D]`
|
||||
- re: :math:`[H or 1, V, D]`
|
||||
- logits: :math:`[B, H, T, V]`
|
||||
"""
|
||||
# logits = torch.einsum('bhld, kmd -> bhlm', [query, re.to(query.dtype)])
|
||||
logits = torch.matmul(query, re.unsqueeze(0).transpose(-2, -1))
|
||||
|
@ -202,10 +210,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
@staticmethod
|
||||
def _relative_position_to_absolute_position(x):
|
||||
"""Converts tensor from relative to absolute indexing for local attention.
|
||||
Args:
|
||||
x: [B, D, length, 2 * length - 1]
|
||||
Shapes:
|
||||
x: :math:`[B, C, T, 2 * T - 1]`
|
||||
Returns:
|
||||
A Tensor of shape [B, D, length, length]
|
||||
A Tensor of shape :math:`[B, C, T, T]`
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# Pad to shift from relative to absolute indexing.
|
||||
|
@ -220,8 +228,9 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
@staticmethod
|
||||
def _absolute_position_to_relative_position(x):
|
||||
"""
|
||||
x: [B, H, T, T]
|
||||
ret: [B, H, T, 2*T-1]
|
||||
Shapes:
|
||||
- x: :math:`[B, C, T, T]`
|
||||
- ret: :math:`[B, C, T, 2*T-1]`
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# pad along column
|
||||
|
@ -239,7 +248,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
|
|||
Args:
|
||||
length (int): an integer scalar.
|
||||
Returns:
|
||||
a Tensor with shape [1, 1, length, length]
|
||||
a Tensor with shape :math:`[1, 1, T, T]`
|
||||
"""
|
||||
# L
|
||||
r = torch.arange(length, dtype=torch.float32)
|
||||
|
@ -362,8 +371,8 @@ class RelativePositionTransformer(nn.Module):
|
|||
def forward(self, x, x_mask):
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, C, T]
|
||||
x_mask: [B, 1, T]
|
||||
- x: :math:`[B, C, T]`
|
||||
- x_mask: :math:`[B, 1, T]`
|
||||
"""
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
for i in range(self.num_layers):
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.ssim import ssim
|
||||
|
||||
|
||||
|
@ -462,13 +462,12 @@ class MDNLoss(nn.Module):
|
|||
|
||||
class AlignTTSLoss(nn.Module):
|
||||
"""Modified AlignTTS Loss.
|
||||
Computes following losses
|
||||
Computes
|
||||
- L1 and SSIM losses from output spectrograms.
|
||||
- Huber loss for duration predictor.
|
||||
- MDNLoss for Mixture of Density Network.
|
||||
|
||||
All the losses are aggregated by a weighted sum with the loss alphas.
|
||||
Alphas can be scheduled based on number of steps.
|
||||
All loss values are aggregated by a weighted sum, with the alpha values as weights.
|
||||
|
||||
Args:
|
||||
c (dict): TTS model configuration.
|
||||
|
@ -487,9 +486,9 @@ class AlignTTSLoss(nn.Module):
|
|||
self.mdn_alpha = c.mdn_alpha
|
||||
|
||||
def forward(
|
||||
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase
|
||||
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, phase
|
||||
):
|
||||
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
|
||||
# ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
|
||||
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
|
||||
if phase == 0:
|
||||
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
|
||||
|
@ -507,36 +506,10 @@ class AlignTTSLoss(nn.Module):
|
|||
spec_loss = self.spec_loss(decoder_output, decoder_target, decoder_output_lens)
|
||||
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
|
||||
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
|
||||
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
|
||||
loss = (
|
||||
self.spec_loss_alpha * spec_loss
|
||||
+ self.ssim_alpha * ssim_loss
|
||||
+ self.dur_loss_alpha * dur_loss
|
||||
+ self.mdn_alpha * mdn_loss
|
||||
)
|
||||
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
|
||||
|
||||
@staticmethod
|
||||
def _set_alpha(step, alpha_settings):
|
||||
"""Set the loss alpha wrt number of steps.
|
||||
Return the corresponding value if no schedule is set.
|
||||
|
||||
Example:
|
||||
Setting an alpha schedule.
|
||||
if ```alpha_settings``` is ```[[0, 1], [10000, 0.1]]``` then ```return_alpha == 1``` until 10k steps, then set to 0.1.
|
||||
if ```alpha_settings``` is a constant value then ```return_alpha``` is set to that constant.
|
||||
|
||||
Args:
|
||||
step (int): number of training steps.
|
||||
alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above.
|
||||
"""
|
||||
return_alpha = None
|
||||
if isinstance(alpha_settings, list):
|
||||
for key, alpha in alpha_settings:
|
||||
if key < step:
|
||||
return_alpha = alpha
|
||||
elif isinstance(alpha_settings, (float, int)):
|
||||
return_alpha = alpha_settings
|
||||
return return_alpha
|
||||
|
||||
def set_alphas(self, step):
|
||||
"""Set the alpha values for all the loss functions"""
|
||||
ssim_alpha = self._set_alpha(step, self.ssim_alpha)
|
||||
dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
|
||||
spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)
|
||||
mdn_alpha = self._set_alpha(step, self.mdn_alpha)
|
||||
return ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha
|
||||
|
|
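For the record, the scheduling semantics being removed here: a list of `[step, value]` pairs picks the value of the last entry whose step is below the current step, while a scalar is used as-is. A standalone re-sketch of that logic, assuming entries are sorted by step:

def set_alpha(step, alpha_settings):
    if isinstance(alpha_settings, list):
        value = None
        for key, alpha in alpha_settings:  # entries assumed sorted by step
            if key < step:
                value = alpha
        return value
    return alpha_settings  # constant alpha

print(set_alpha(500, [[0, 1], [10000, 0.1]]))    # 1
print(set_alpha(20000, [[0, 1], [10000, 0.1]]))  # 0.1
print(set_alpha(500, 0.25))                      # 0.25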
|
@ -8,10 +8,10 @@ class GST(nn.Module):
|
|||
|
||||
See https://arxiv.org/pdf/1803.09017"""
|
||||
|
||||
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim=None):
|
||||
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim=None):
|
||||
super().__init__()
|
||||
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
|
||||
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim)
|
||||
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, d_vector_dim)
|
||||
|
||||
def forward(self, inputs, speaker_embedding=None):
|
||||
enc_out = self.encoder(inputs)
|
||||
|
@ -83,13 +83,13 @@ class ReferenceEncoder(nn.Module):
|
|||
class StyleTokenLayer(nn.Module):
|
||||
"""NN Module attending to style tokens based on prosody encodings."""
|
||||
|
||||
def __init__(self, num_heads, num_style_tokens, embedding_dim, speaker_embedding_dim=None):
|
||||
def __init__(self, num_heads, num_style_tokens, embedding_dim, d_vector_dim=None):
|
||||
super().__init__()
|
||||
|
||||
self.query_dim = embedding_dim // 2
|
||||
|
||||
if speaker_embedding_dim:
|
||||
self.query_dim += speaker_embedding_dim
|
||||
if d_vector_dim:
|
||||
self.query_dim += d_vector_dim
|
||||
|
||||
self.key_dim = embedding_dim // num_heads
|
||||
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
# coding: utf-8
|
||||
# adapted from https://github.com/r9y9/tacotron_pytorch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
@ -266,7 +268,8 @@ class Decoder(nn.Module):
|
|||
location_attn (bool): if true, use location sensitive attention.
|
||||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
speaker_embedding_dim (int): size of speaker embedding vector, for multi-speaker training.
|
||||
d_vector_dim (int): size of speaker embedding vector, for multi-speaker training.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 500.
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
|
@ -289,12 +292,13 @@ class Decoder(nn.Module):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
):
|
||||
super().__init__()
|
||||
self.r_init = r
|
||||
self.r = r
|
||||
self.in_channels = in_channels
|
||||
self.max_decoder_steps = 500
|
||||
self.max_decoder_steps = max_decoder_steps
|
||||
self.use_memory_queue = memory_size > 0
|
||||
self.memory_size = memory_size if memory_size > 0 else r
|
||||
self.frame_channels = frame_channels
|
||||
|
|
|
@ -135,6 +135,7 @@ class Decoder(nn.Module):
|
|||
location_attn (bool): if true, use location sensitive attention.
|
||||
attn_K (int): number of attention heads for GravesAttention.
|
||||
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
|
||||
max_decoder_steps (int): Maximum number of steps allowed for the decoder. Defaults to 10000.
|
||||
"""
|
||||
|
||||
# Pylint gets confused by PyTorch conventions here
|
||||
|
@ -155,6 +156,7 @@ class Decoder(nn.Module):
|
|||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
max_decoder_steps,
|
||||
):
|
||||
super().__init__()
|
||||
self.frame_channels = frame_channels
|
||||
|
@ -162,7 +164,7 @@ class Decoder(nn.Module):
|
|||
self.r = r
|
||||
self.encoder_embedding_dim = in_channels
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.max_decoder_steps = 1000
|
||||
self.max_decoder_steps = max_decoder_steps
|
||||
self.stop_threshold = 0.5
|
||||
|
||||
# model dimensions
|
||||
|
@ -355,7 +357,7 @@ class Decoder(nn.Module):
|
|||
if stop_token > self.stop_threshold and t > inputs.shape[0] // 2:
|
||||
break
|
||||
if len(outputs) == self.max_decoder_steps:
|
||||
print(" | > Decoder stopped with 'max_decoder_steps")
|
||||
print(f" > Decoder stopped with `max_decoder_steps` {self.max_decoder_steps}")
|
||||
break
|
||||
|
||||
memory = self._update_memory(decoder_output)
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
from TTS.tts.utils.text.symbols import make_symbols, parse_symbols
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
def setup_model(config):
|
||||
print(" > Using model: {}".format(config.model))
|
||||
|
||||
MyModel = find_module("TTS.tts.models", config.model.lower())
|
||||
# define set of characters used by the model
|
||||
if config.characters is not None:
|
||||
# set characters from config
|
||||
symbols, phonemes = make_symbols(**config.characters.to_dict()) # pylint: disable=redefined-outer-name
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import phonemes, symbols # pylint: disable=import-outside-toplevel
|
||||
|
||||
# use default characters and assign them to config
|
||||
config.characters = parse_symbols()
|
||||
num_chars = len(phonemes) if config.use_phonemes else len(symbols)
|
||||
# consider special `blank` character if `add_blank` is set True
|
||||
num_chars = num_chars + getattr(config, "add_blank", False)
|
||||
config.num_chars = num_chars
|
||||
# compatibility fix
|
||||
if "model_params" in config:
|
||||
config.model_params.num_chars = num_chars
|
||||
if "model_args" in config:
|
||||
config.model_args.num_chars = num_chars
|
||||
model = MyModel(config)
|
||||
return model
|
||||
|
||||
|
||||
# TODO: class registry
|
||||
# def import_models(models_dir, namespace):
|
||||
# for file in os.listdir(models_dir):
|
||||
# path = os.path.join(models_dir, file)
|
||||
# if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)):
|
||||
# model_name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
# importlib.import_module(namespace + "." + model_name)
|
||||
#
|
||||
#
|
||||
## automatically import any Python files in the models/ directory
|
||||
# models_dir = os.path.dirname(__file__)
|
||||
# import_models(models_dir, "TTS.tts.models")
|
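The character-count bookkeeping in `setup_model` reduces to a few lines; a sketch with assumed flag values. Note that the original `num_chars + getattr(config, "add_blank", False)` relies on bool-to-int coercion to reserve one extra id for the blank token:

from TTS.tts.utils.text.symbols import phonemes, symbols

use_phonemes, add_blank = True, True  # assumed flags
num_chars = (len(phonemes) if use_phonemes else len(symbols)) + int(add_blank)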
|
@ -1,5 +1,9 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coqpit import Coqpit
|
||||
|
||||
from TTS.tts.layers.align_tts.mdn import MDNBlock
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
|
@ -7,32 +11,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class AlignTTS(nn.Module):
|
||||
"""AlignTTS with modified duration predictor.
|
||||
https://arxiv.org/pdf/2003.01950.pdf
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
AlignTTS's Abstract - Targeting at both high efficiency and performance, we propose AlignTTS to predict the
|
||||
mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
|
||||
sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
|
||||
adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
|
||||
to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
|
||||
show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean
|
||||
opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.
|
||||
|
||||
Note:
|
||||
Original model uses a separate character embedding layer for duration predictor. However, it causes the
|
||||
duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
|
||||
we predict durations based on encoder outputs which has higher level information about input characters. This
|
||||
enables training without phases as in the original paper.
|
||||
|
||||
Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
|
||||
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
|
||||
|
||||
@dataclass
|
||||
class AlignTTSArgs(Coqpit):
|
||||
"""
|
||||
Args:
|
||||
num_chars (int):
|
||||
number of unique input to characters
|
||||
|
@ -60,42 +48,102 @@ class AlignTTS(nn.Module):
|
|||
number of channels in speaker embedding vectors. Defaults to 0.
|
||||
"""
|
||||
|
||||
num_chars: int = None
|
||||
out_channels: int = 80
|
||||
hidden_channels: int = 256
|
||||
hidden_channels_dp: int = 256
|
||||
encoder_type: str = "fftransformer"
|
||||
encoder_params: dict = field(
|
||||
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
|
||||
)
|
||||
decoder_type: str = "fftransformer"
|
||||
decoder_params: dict = field(
|
||||
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}
|
||||
)
|
||||
length_scale: float = 1.0
|
||||
num_speakers: int = 0
|
||||
use_speaker_embedding: bool = False
|
||||
use_d_vector_file: bool = False
|
||||
d_vector_dim: int = 0
|
||||
|
||||
|
||||
class AlignTTS(BaseTTS):
|
||||
"""AlignTTS with modified duration predictor.
|
||||
https://arxiv.org/pdf/2003.01950.pdf
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
Check :class:`AlignTTSArgs` for the class arguments.
|
||||
|
||||
Paper Abstract:
|
||||
Targeting at both high efficiency and performance, we propose AlignTTS to predict the
|
||||
mel-spectrum in parallel. AlignTTS is based on a Feed-Forward Transformer which generates mel-spectrum from a
|
||||
sequence of characters, and the duration of each character is determined by a duration predictor. Instead of
|
||||
adopting the attention mechanism in Transformer TTS to align text to mel-spectrum, the alignment loss is presented
|
||||
to consider all possible alignments in training by use of dynamic programming. Experiments on the LJSpeech dataset
|
||||
show that our model achieves not only state-of-the-art performance which outperforms Transformer TTS by 0.03 in mean
|
||||
opinion score (MOS), but also a high efficiency which is more than 50 times faster than real-time.
|
||||
|
||||
Note:
|
||||
Original model uses a separate character embedding layer for duration predictor. However, it causes the
|
||||
duration predictor to overfit and prevents learning higher level interactions among characters. Therefore,
|
||||
we predict durations based on encoder outputs which have higher level information about input characters. This
|
||||
enables training without phases as in the original paper.
|
||||
|
||||
Original model uses Transformers in encoder and decoder layers. However, here you can set the architecture
|
||||
differently based on your requirements using ```encoder_type``` and ```decoder_type``` parameters.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs import AlignTTSConfig
|
||||
>>> config = AlignTTSConfig()
|
||||
>>> model = AlignTTS(config)
|
||||
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels=256,
|
||||
hidden_channels_dp=256,
|
||||
encoder_type="fftransformer",
|
||||
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
decoder_type="fftransformer",
|
||||
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
|
||||
length_scale=1,
|
||||
num_speakers=0,
|
||||
external_c=False,
|
||||
c_in_channels=0,
|
||||
):
|
||||
def __init__(self, config: Coqpit):
|
||||
|
||||
super().__init__()
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels_dp)
|
||||
self.config = config
|
||||
self.phase = -1
|
||||
self.length_scale = (
|
||||
float(config.model_args.length_scale)
|
||||
if isinstance(config.model_args.length_scale, int)
|
||||
else config.model_args.length_scale
|
||||
)
|
||||
|
||||
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
self.mdn_block = MDNBlock(hidden_channels, 2 * out_channels)
|
||||
if not self.config.model_args.num_chars:
|
||||
_, self.config, num_chars = self.get_characters(config)
|
||||
self.config.model_args.num_chars = num_chars
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
self.emb = nn.Embedding(self.config.model_args.num_chars, self.config.model_args.hidden_channels)
|
||||
|
||||
if c_in_channels > 0 and c_in_channels != hidden_channels:
|
||||
self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)
|
||||
self.embedded_speaker_dim = 0
|
||||
self.init_multispeaker(config)
|
||||
|
||||
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
|
||||
self.encoder = Encoder(
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.encoder_type,
|
||||
config.model_args.encoder_params,
|
||||
self.embedded_speaker_dim,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
config.model_args.out_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.decoder_type,
|
||||
config.model_args.decoder_params,
|
||||
)
|
||||
self.duration_predictor = DurationPredictor(config.model_args.hidden_channels_dp)
|
||||
|
||||
self.mod_layer = nn.Conv1d(config.model_args.hidden_channels, config.model_args.hidden_channels, 1)
|
||||
|
||||
self.mdn_block = MDNBlock(config.model_args.hidden_channels, 2 * config.model_args.out_channels)
|
||||
|
||||
if self.embedded_speaker_dim > 0 and self.embedded_speaker_dim != config.model_args.hidden_channels:
|
||||
self.proj_g = nn.Conv1d(self.embedded_speaker_dim, config.model_args.hidden_channels, 1)
|
||||
|
||||
@staticmethod
|
||||
def compute_log_probs(mu, log_sigma, y):
|
||||
|
@ -129,15 +177,15 @@ class AlignTTS(nn.Module):
|
|||
"""Generate attention alignment map from durations and
|
||||
expand encoder outputs
|
||||
|
||||
Example:
|
||||
encoder output: [a,b,c,d]
|
||||
durations: [1, 3, 2, 1]
|
||||
Examples::
|
||||
- encoder output: [a,b,c,d]
|
||||
- durations: [1, 3, 2, 1]
|
||||
|
||||
expanded: [a, b, b, b, c, c, d]
|
||||
attention map: [[0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1, 1, 0],
|
||||
[0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0]]
|
||||
- expanded: [a, b, b, b, c, c, d]
|
||||
- attention map: [[0, 0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 1, 1, 0],
|
||||
[0, 1, 1, 1, 0, 0, 0],
|
||||
[1, 0, 0, 0, 0, 0, 0]]
|
||||
"""
|
||||
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
|
||||
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
|
||||
|
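As a minimal standalone sketch of the duration-to-alignment expansion described in the docstring above (the helper below is illustrative, not part of this module):

import torch

def durations_to_alignment(dr):
    """Expand integer durations [T_en] into a one-hot alignment map [T_de, T_en]."""
    ends = torch.cumsum(dr, dim=0)             # last frame (exclusive) per input token
    starts = ends - dr                         # first frame per input token
    frames = torch.arange(int(dr.sum())).unsqueeze(1)  # [T_de, 1]
    return ((frames >= starts) & (frames < ends)).float()

attn = durations_to_alignment(torch.tensor([1, 3, 2, 1]))
print(attn.shape)  # torch.Size([7, 4]) -> frames expand to [a, b, b, b, c, c, d]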
@ -159,11 +207,12 @@ class AlignTTS(nn.Module):
|
|||
# project g to decoder dim.
|
||||
if hasattr(self, "proj_g"):
|
||||
g = self.proj_g(g)
|
||||
|
||||
return x + g
|
||||
|
||||
def _forward_encoder(self, x, x_lengths, g=None):
|
||||
if hasattr(self, "emb_g"):
|
||||
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
|
||||
g = nn.functional.normalize(self.speaker_embedding(g)) # [B, C, 1]
|
||||
|
||||
if g is not None:
|
||||
g = g.unsqueeze(-1)
|
||||
|
@ -207,15 +256,19 @@ class AlignTTS(nn.Module):
|
|||
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
|
||||
return dr_mas, mu, log_sigma, logp
|
||||
|
||||
def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths, aux_input={"d_vectors": None}, phase=None
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
y_lengths: [B]
|
||||
dr: [B, T_max]
|
||||
g: [B, C]
|
||||
- x: :math:`[B, T_max]`
|
||||
- x_lengths: :math:`[B]`
|
||||
- y_lengths: :math:`[B]`
|
||||
- dr: :math:`[B, T_max]`
|
||||
- g: :math:`[B, C]`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp = None, None, None, None, None, None, None
|
||||
if phase == 0:
|
||||
# train encoder and MDN
|
||||
|
@ -247,16 +300,27 @@ class AlignTTS(nn.Module):
|
|||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
|
||||
o_dr_log = o_dr_log.squeeze(1)
|
||||
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
|
||||
return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp
|
||||
outputs = {
|
||||
"model_outputs": o_de.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dr_log,
|
||||
"durations_mas_log": dr_mas_log,
|
||||
"mu": mu,
|
||||
"log_sigma": log_sigma,
|
||||
"logp": logp,
|
||||
}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
|
||||
def inference(self, x, aux_input={"d_vectors": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
- x: :math:`[B, T_max]`
|
||||
- x_lengths: :math:`[B]`
|
||||
- g: :math:`[B, C]`
|
||||
"""
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# pad input to prevent dropping the last word
|
||||
# x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
|
@ -266,7 +330,61 @@ class AlignTTS(nn.Module):
|
|||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
return o_de, attn
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module):
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
|
||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input, self.phase)
|
||||
loss_dict = criterion(
|
||||
outputs["logp"],
|
||||
outputs["model_outputs"],
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
outputs["durations_mas_log"],
|
||||
text_lengths,
|
||||
phase=self.phase,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower, the better)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(
|
||||
self, ap: AudioProcessor, batch: dict, outputs: dict
|
||||
) -> Tuple[Dict, Dict]: # pylint: disable=no-self-use
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
|
@ -276,3 +394,29 @@ class AlignTTS(nn.Module):
|
|||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import AlignTTSLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return AlignTTSLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def _set_phase(config, global_step):
|
||||
"""Decide AlignTTS training phase"""
|
||||
if isinstance(config.phase_start_steps, list):
|
||||
vals = [i < global_step for i in config.phase_start_steps]
|
||||
if True not in vals:
|
||||
phase = 0
|
||||
else:
|
||||
phase = (
|
||||
len(config.phase_start_steps)
|
||||
- [i < global_step for i in config.phase_start_steps][::-1].index(True)
|
||||
- 1
|
||||
)
|
||||
else:
|
||||
phase = None
|
||||
return phase
|
||||
|
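For illustration, a standalone sketch of how this schedule resolves a global step to a phase (the step values are hypothetical):

phase_start_steps = [0, 40_000, 80_000, 160_000]  # illustrative schedule
for global_step in (10, 50_000, 200_000):
    vals = [i < global_step for i in phase_start_steps]
    phase = 0 if True not in vals else len(phase_start_steps) - vals[::-1].index(True) - 1
    print(global_step, "->", phase)  # 10 -> 0, 50000 -> 1, 200000 -> 3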
||||
def on_epoch_start(self, trainer):
|
||||
"""Set AlignTTS training phase on epoch start."""
|
||||
self.phase = self._set_phase(trainer.config, trainer.total_steps_done)
|
||||
|
|
|
@ -0,0 +1,284 @@
|
|||
import copy
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
from coqpit import MISSING, Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.losses import TacotronLoss
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.utils.generic_utils import format_aux_input
|
||||
from TTS.utils.training import gradual_training_scheduler
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseTacotronArgs(Coqpit):
|
||||
"""TODO: update Tacotron configs using it"""
|
||||
|
||||
num_chars: int = MISSING
|
||||
num_speakers: int = MISSING
|
||||
r: int = MISSING
|
||||
out_channels: int = 80
|
||||
decoder_output_dim: int = 80
|
||||
attn_type: str = "original"
|
||||
attn_win: bool = False
|
||||
attn_norm: str = "softmax"
|
||||
prenet_type: str = "original"
|
||||
prenet_dropout: bool = True
|
||||
prenet_dropout_at_inference: bool = False
|
||||
forward_attn: bool = False
|
||||
trans_agent: bool = False
|
||||
forward_attn_mask: bool = False
|
||||
location_attn: bool = True
|
||||
attn_K: int = 5
|
||||
separate_stopnet: bool = True
|
||||
bidirectional_decoder: bool = False
|
||||
double_decoder_consistency: bool = False
|
||||
ddc_r: int = None
|
||||
encoder_in_features: int = 512
|
||||
decoder_in_features: int = 512
|
||||
d_vector_dim: int = None
|
||||
use_gst: bool = False
|
||||
gst: bool = None
|
||||
gradual_training: bool = None
|
||||
|
||||
|
||||
class BaseTacotron(BaseTTS):
|
||||
def __init__(self, config: Coqpit):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__()
|
||||
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# init tensors
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
# init multi-speaker layers
|
||||
self.init_multispeaker(config)
|
||||
|
||||
@staticmethod
|
||||
def _format_aux_input(aux_input: Dict) -> Dict:
|
||||
return format_aux_input({"d_vectors": None, "speaker_ids": None}, aux_input)
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.embedded_speakers = None
|
||||
self.embedded_speakers_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
if "r" in state:
|
||||
self.decoder.set_r(state["r"])
|
||||
else:
|
||||
self.decoder.set_r(state["config"]["r"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self) -> nn.Module:
|
||||
return TacotronLoss(self.config)
|
||||
|
||||
@staticmethod
|
||||
def get_characters(config: Coqpit) -> str:
|
||||
# TODO: implement CharacterProcessor
|
||||
if config.characters is not None:
|
||||
symbols, phonemes = make_symbols(**config.characters)
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import ( # pylint: disable=import-outside-toplevel
|
||||
parse_symbols,
|
||||
phonemes,
|
||||
symbols,
|
||||
)
|
||||
|
||||
config.characters = parse_symbols()
|
||||
model_characters = phonemes if config.use_phonemes else symbols
|
||||
return model_characters, config
|
||||
|
||||
@staticmethod
|
||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Compute Tacotron's auxiliary inputs based on model config.
|
||||
- speaker d_vector
|
||||
- style wav for GST
|
||||
- speaker ID for speaker embedding
|
||||
"""
|
||||
# setup speaker_id
|
||||
if self.config.use_speaker_embedding:
|
||||
speaker_id = kwargs.get("speaker_id", 0)
|
||||
else:
|
||||
speaker_id = None
|
||||
# setup d_vector
|
||||
d_vector = (
|
||||
self.speaker_manager.get_d_vectors_by_speaker(self.speaker_manager.speaker_names[0])
|
||||
if self.config.use_d_vector_file and self.config.use_speaker_embedding
|
||||
else None
|
||||
)
|
||||
# setup style_mel
|
||||
if "style_wav" in kwargs:
|
||||
style_wav = kwargs["style_wav"]
|
||||
elif self.config.has("gst_style_input"):
|
||||
style_wav = self.config.gst_style_input
|
||||
else:
|
||||
style_wav = None
|
||||
if style_wav is None and "use_gst" in self.config and self.config.use_gst:
|
||||
# initialize GST with a zero dict.
|
||||
style_wav = {}
|
||||
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
|
||||
for i in range(self.config.gst["gst_num_style_tokens"]):
|
||||
style_wav[str(i)] = 0
|
||||
aux_inputs = {"speaker_id": speaker_id, "style_wav": style_wav, "d_vector": d_vector}
|
||||
return aux_inputs
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
input_mask = sequence_mask(text_lengths)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len)
|
||||
return input_mask, output_mask
|
||||
|
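A quick sketch of the rounding-up step used above, padding the max mel length to a multiple of the reduction factor (values are illustrative):

import torch

mel_lengths = torch.tensor([100, 95, 87])
r = 7                                   # hypothetical reduction factor
max_len = mel_lengths.max()
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
print(int(max_len))                     # 105, the next multiple of r >= 100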
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
"""Run backwards decoder"""
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
|
||||
)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
|
||||
"""Double Decoder Consistency"""
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask
|
||||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
|
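As a standalone sketch of the alignment-upsampling step used for Double Decoder Consistency (shapes here are illustrative assumptions):

import torch

coarse_align = torch.rand(2, 50, 40)    # [B, T_coarse, T_en] from the coarse decoder (larger r)
fine_T = 100                            # fine decoder time resolution
upsampled = torch.nn.functional.interpolate(
    coarse_align.transpose(1, 2), size=fine_T, mode="nearest"
).transpose(1, 2)
print(upsampled.shape)                  # torch.Size([2, 100, 40])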
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
"""Compute speaker embedding vectors"""
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.embedded_speakers = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.embedded_speakers_projected = self.speaker_project_mel(self.embedded_speakers).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
if isinstance(style_input, dict):
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).type_as(inputs)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
for k_token, v_amplifier in style_input.items():
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).type_as(inputs)
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
|
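Usage sketch for the dictionary form of `style_input` (the token indices and amplifier values below are hypothetical):

# Weight GST tokens directly instead of encoding a reference wav; keys are
# token indices as strings, values scale each token's attention output.
style_input = {"0": 0.15, "1": -0.2, "3": 0.3}
# inputs = model.compute_gst(encoder_outputs, style_input)  # model/tensors assumed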
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + embedded_speakers_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, embedded_speakers):
|
||||
embedded_speakers_ = embedded_speakers.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, embedded_speakers_], dim=-1)
|
||||
return outputs
|
||||
|
||||
#############################
|
||||
# CALLBACKS
|
||||
#############################
|
||||
|
||||
def on_epoch_start(self, trainer):
|
||||
"""Callback for setting values wrt gradual training schedule.
|
||||
|
||||
Args:
|
||||
trainer (TrainerTTS): TTS trainer object that is used to train this model.
|
||||
"""
|
||||
if self.gradual_training:
|
||||
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
|
||||
trainer.config.r = r
|
||||
self.decoder.set_r(r)
|
||||
if trainer.config.bidirectional_decoder:
|
||||
trainer.model.decoder_backward.set_r(r)
|
||||
print(f"\n > Number of output frames: {self.decoder.r}")
|
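A minimal sketch of how a `gradual_training` schedule maps the global step to (r, batch_size); the entries below are illustrative, not from the source:

schedule = [[0, 7, 64], [10_000, 5, 64], [50_000, 3, 32]]  # [step, r, batch_size]

def lookup(step):
    r, bs = schedule[0][1], schedule[0][2]
    for boundary, new_r, new_bs in schedule:
        if step >= boundary:
            r, bs = new_r, new_bs
    return r, bs

print(lookup(12_345))  # (5, 64)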
|
@ -0,0 +1,234 @@
|
|||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from TTS.model import BaseModel
|
||||
from TTS.tts.datasets import TTSDataset
|
||||
from TTS.tts.utils.speakers import SpeakerManager, get_speaker_manager
|
||||
from TTS.tts.utils.synthesis import synthesis
|
||||
from TTS.tts.utils.text import make_symbols
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
|
||||
class BaseTTS(BaseModel):
|
||||
"""Abstract `tts` class. Every new `tts` model must inherit this.
|
||||
|
||||
It defines `tts` specific functions on top of `Model`.
|
||||
|
||||
Notes on input/output tensor shapes:
|
||||
Any input or output tensor of the model must be shaped as
|
||||
|
||||
- 3D tensors `batch x time x channels`
|
||||
- 2D tensors `batch x channels`
|
||||
- 1D tensors `batch x 1`
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def get_characters(config: Coqpit) -> str:
|
||||
# TODO: implement CharacterProcessor
|
||||
if config.characters is not None:
|
||||
symbols, phonemes = make_symbols(**config.characters)
|
||||
else:
|
||||
from TTS.tts.utils.text.symbols import parse_symbols, phonemes, symbols
|
||||
|
||||
config.characters = parse_symbols()
|
||||
model_characters = phonemes if config.use_phonemes else symbols
|
||||
num_chars = len(model_characters) + getattr(config, "add_blank", False)
|
||||
return model_characters, config, num_chars
|
||||
|
||||
def get_speaker_manager(config: Coqpit, restore_path: str, data: List, out_path: str = None) -> SpeakerManager:
|
||||
return get_speaker_manager(config, restore_path, data, out_path)
|
||||
|
||||
def init_multispeaker(self, config: Coqpit, data: List = None):
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = (
|
||||
config.d_vector_dim if "d_vector_dim" in config and config.d_vector_dim is not None else 512
|
||||
)
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
def get_aux_input(self, **kwargs) -> Dict:
|
||||
"""Prepare and return `aux_input` used by `forward()`"""
|
||||
pass
|
||||
|
||||
def format_batch(self, batch: Dict) -> Dict:
|
||||
"""Generic batch formatting for `TTSDataset`.
|
||||
|
||||
You must override this if you use a custom dataset.
|
||||
|
||||
Args:
|
||||
batch (Dict): Raw batch as returned by the dataset's collate function.
|
||||
|
||||
Returns:
|
||||
Dict: Formatted batch with named fields.
|
||||
"""
|
||||
# setup input batch
|
||||
text_input = batch[0]
|
||||
text_lengths = batch[1]
|
||||
speaker_names = batch[2]
|
||||
linear_input = batch[3] if self.config.model.lower() in ["tacotron"] else None
|
||||
mel_input = batch[4]
|
||||
mel_lengths = batch[5]
|
||||
stop_targets = batch[6]
|
||||
item_idx = batch[7]
|
||||
d_vectors = batch[8]
|
||||
speaker_ids = batch[9]
|
||||
attn_mask = batch[10]
|
||||
max_text_length = torch.max(text_lengths.float())
|
||||
max_spec_length = torch.max(mel_lengths.float())
|
||||
|
||||
# compute durations from attention masks
|
||||
durations = None
|
||||
if attn_mask is not None:
|
||||
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
|
||||
for idx, am in enumerate(attn_mask):
|
||||
# compute raw durations
|
||||
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
|
||||
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
|
||||
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
|
||||
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
|
||||
dur[c_idxs] = counts
|
||||
# smooth the durations and set any 0 duration to 1
|
||||
# by cutting off from the largest duration indices.
|
||||
extra_frames = dur.sum() - mel_lengths[idx]
|
||||
largest_idxs = torch.argsort(-dur)[:extra_frames]
|
||||
dur[largest_idxs] -= 1
|
||||
assert (
|
||||
dur.sum() == mel_lengths[idx]
|
||||
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
|
||||
durations[idx, : text_lengths[idx]] = dur
|
||||
|
||||
# set stop targets view; we predict a single stop token per decoder iteration.
|
||||
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1)
|
||||
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
|
||||
|
||||
return {
|
||||
"text_input": text_input,
|
||||
"text_lengths": text_lengths,
|
||||
"speaker_names": speaker_names,
|
||||
"mel_input": mel_input,
|
||||
"mel_lengths": mel_lengths,
|
||||
"linear_input": linear_input,
|
||||
"stop_targets": stop_targets,
|
||||
"attn_mask": attn_mask,
|
||||
"durations": durations,
|
||||
"speaker_ids": speaker_ids,
|
||||
"d_vectors": d_vectors,
|
||||
"max_text_length": float(max_text_length),
|
||||
"max_spec_length": float(max_spec_length),
|
||||
"item_idx": item_idx,
|
||||
}
|
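A standalone sketch of the duration extraction above, run on a tiny hand-made attention map (shapes simplified here to [T_de, T_en]):

import torch

attn = torch.tensor([[1, 0, 0, 0],
                     [0, 1, 0, 0],
                     [0, 1, 0, 0],
                     [0, 1, 0, 0],
                     [0, 0, 1, 0],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]]).float()
c_idxs = attn.max(1)[1]                           # winning input token per output frame
idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones(attn.shape[1], dtype=counts.dtype)  # unattended tokens keep duration 1
dur[idxs] = counts
print(dur.tolist())                               # [1, 3, 2, 1]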
||||
|
||||
def get_data_loader(
|
||||
self, config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool, num_gpus: int
|
||||
) -> "DataLoader":
|
||||
if is_eval and not config.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
# setup multi-speaker attributes
|
||||
if hasattr(self, "speaker_manager"):
|
||||
speaker_id_mapping = self.speaker_manager.speaker_ids if config.use_speaker_embedding else None
|
||||
d_vector_mapping = (
|
||||
self.speaker_manager.d_vectors
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None
|
||||
)
|
||||
else:
|
||||
speaker_id_mapping = None
|
||||
d_vector_mapping = None
|
||||
|
||||
# init dataloader
|
||||
dataset = TTSDataset(
|
||||
outputs_per_step=config.r if "r" in config else 1,
|
||||
text_cleaner=config.text_cleaner,
|
||||
compute_linear_spec=config.model.lower() == "tacotron",
|
||||
meta_data=data_items,
|
||||
ap=ap,
|
||||
characters=config.characters,
|
||||
add_blank=config["add_blank"],
|
||||
batch_group_size=0 if is_eval else config.batch_group_size * config.batch_size,
|
||||
min_seq_len=config.min_seq_len,
|
||||
max_seq_len=config.max_seq_len,
|
||||
phoneme_cache_path=config.phoneme_cache_path,
|
||||
use_phonemes=config.use_phonemes,
|
||||
phoneme_language=config.phoneme_language,
|
||||
enable_eos_bos=config.enable_eos_bos_chars,
|
||||
use_noise_augment=not is_eval,
|
||||
verbose=verbose,
|
||||
speaker_id_mapping=speaker_id_mapping,
|
||||
d_vector_mapping=d_vector_mapping
|
||||
if config.use_speaker_embedding and config.use_d_vector_file
|
||||
else None,
|
||||
)
|
||||
|
||||
if config.use_phonemes and config.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
dataset.compute_input_seq(config.num_loader_workers)
|
||||
dataset.sort_items()
|
||||
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=config.eval_batch_size if is_eval else config.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=config.num_eval_loader_workers if is_eval else config.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def test_run(self) -> Tuple[Dict, Dict]:
|
||||
"""Generic test run for `tts` models used by `Trainer`.
|
||||
|
||||
You can override this for a different behaviour.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict, Dict]: Test figures and audios to be projected to Tensorboard.
|
||||
"""
|
||||
print(" | > Synthesizing test sentences.")
|
||||
test_audios = {}
|
||||
test_figures = {}
|
||||
test_sentences = self.config.test_sentences
|
||||
aux_inputs = self._get_aux_inputs()
|
||||
for idx, sen in enumerate(test_sentences):
|
||||
wav, alignment, model_outputs, _ = synthesis(
|
||||
self.model,
|
||||
sen,
|
||||
self.config,
|
||||
self.use_cuda,
|
||||
self.ap,
|
||||
speaker_id=aux_inputs["speaker_id"],
|
||||
d_vector=aux_inputs["d_vector"],
|
||||
style_wav=aux_inputs["style_wav"],
|
||||
enable_eos_bos_chars=self.config.enable_eos_bos_chars,
|
||||
use_griffin_lim=True,
|
||||
do_trim_silence=False,
|
||||
).values()
|
||||
|
||||
test_audios["{}-audio".format(idx)] = wav
|
||||
test_figures["{}-prediction".format(idx)] = plot_spectrogram(model_outputs, self.ap, output_fig=False)
|
||||
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
|
||||
return test_figures, test_audios
|
|
@ -4,132 +4,116 @@ import torch
|
|||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from TTS.tts.configs import GlowTTSConfig
|
||||
from TTS.tts.layers.glow_tts.decoder import Decoder
|
||||
from TTS.tts.layers.glow_tts.encoder import Encoder
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.speakers import get_speaker_manager
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class GlowTTS(nn.Module):
|
||||
class GlowTTS(BaseTTS):
|
||||
"""Glow TTS models from https://arxiv.org/abs/2005.11129
|
||||
|
||||
Args:
|
||||
num_chars (int): number of embedding characters.
|
||||
hidden_channels_enc (int): number of embedding and encoder channels.
|
||||
hidden_channels_dec (int): number of decoder channels.
|
||||
use_encoder_prenet (bool): enable/disable prenet for encoder. Prenet modules are hard-coded for each alternative encoder.
|
||||
hidden_channels_dp (int): number of duration predictor channels.
|
||||
out_channels (int): number of output channels. It should be equal to the number of spectrogram filters.
|
||||
num_flow_blocks_dec (int): number of decoder blocks.
|
||||
kernel_size_dec (int): decoder kernel size.
|
||||
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
|
||||
num_block_layers (int): number of decoder layers in each decoder block.
|
||||
dropout_p_dec (float): dropout rate for decoder.
|
||||
num_speakers (int): number of speakers that defines the size of the speaker embedding layer.
|
||||
c_in_channels (int): number of speaker embedding channels. It is set to 512 if embeddings are learned.
|
||||
num_splits (int): number of split levels in the invertible conv1x1 operation.
|
||||
num_squeeze (int): number of squeeze levels. When squeezing, the number of channels increases and the number of time steps decreases by the factor 'num_squeeze'.
|
||||
sigmoid_scale (bool): enable/disable sigmoid scaling in decoder.
|
||||
mean_only (bool): if True, encoder only computes mean value and uses constant variance for each time step.
|
||||
encoder_type (str): encoder module type.
|
||||
encoder_params (dict): encoder module parameters.
|
||||
speaker_embedding_dim (int): channels of external speaker embedding vectors.
|
||||
Paper abstract:
|
||||
Recently, text-to-speech (TTS) models such as FastSpeech and ParaNet have been proposed to generate
|
||||
mel-spectrograms from text in parallel. Despite the advantage, the parallel TTS models cannot be trained
|
||||
without guidance from autoregressive TTS models as their external aligners. In this work, we propose Glow-TTS,
|
||||
a flow-based generative model for parallel TTS that does not require any external aligner. By combining the
|
||||
properties of flows and dynamic programming, the proposed model searches for the most probable monotonic
|
||||
alignment between text and the latent representation of speech on its own. We demonstrate that enforcing hard
|
||||
monotonic alignments enables robust TTS, which generalizes to long utterances, and employing generative flows
|
||||
enables fast, diverse, and controllable speech synthesis. Glow-TTS obtains an order-of-magnitude speed-up over
|
||||
the autoregressive model, Tacotron 2, at synthesis with comparable speech quality. We further show that our
|
||||
model can be easily extended to a multi-speaker setting.
|
||||
|
||||
Check :class:`TTS.tts.configs.glow_tts_config.GlowTTSConfig` for class arguments.
|
||||
|
||||
Examples:
|
||||
>>> from TTS.tts.configs import GlowTTSConfig
|
||||
>>> from TTS.tts.models.glow_tts import GlowTTS
|
||||
>>> config = GlowTTSConfig()
|
||||
>>> model = GlowTTS(config)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
hidden_channels_enc,
|
||||
hidden_channels_dec,
|
||||
use_encoder_prenet,
|
||||
hidden_channels_dp,
|
||||
out_channels,
|
||||
num_flow_blocks_dec=12,
|
||||
inference_noise_scale=0.33,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=5,
|
||||
num_block_layers=4,
|
||||
dropout_p_dp=0.1,
|
||||
dropout_p_dec=0.05,
|
||||
num_speakers=0,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_squeeze=1,
|
||||
sigmoid_scale=False,
|
||||
mean_only=False,
|
||||
encoder_type="transformer",
|
||||
encoder_params=None,
|
||||
speaker_embedding_dim=None,
|
||||
):
|
||||
def __init__(self, config: GlowTTSConfig):
|
||||
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.hidden_channels_dp = hidden_channels_dp
|
||||
self.hidden_channels_enc = hidden_channels_enc
|
||||
self.hidden_channels_dec = hidden_channels_dec
|
||||
self.out_channels = out_channels
|
||||
self.num_flow_blocks_dec = num_flow_blocks_dec
|
||||
self.kernel_size_dec = kernel_size_dec
|
||||
self.dilation_rate = dilation_rate
|
||||
self.num_block_layers = num_block_layers
|
||||
self.dropout_p_dec = dropout_p_dec
|
||||
self.num_speakers = num_speakers
|
||||
self.c_in_channels = c_in_channels
|
||||
self.num_splits = num_splits
|
||||
self.num_squeeze = num_squeeze
|
||||
self.sigmoid_scale = sigmoid_scale
|
||||
self.mean_only = mean_only
|
||||
self.use_encoder_prenet = use_encoder_prenet
|
||||
self.inference_noise_scale = inference_noise_scale
|
||||
|
||||
# model constants.
|
||||
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
|
||||
self.length_scale = 1.0 # scale factor for the duration predictor. The larger it is, the slower the speech.
|
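# Illustrative effect (not from the source): with length_scale = 1.2, predicted
# durations [3, 5, 2] round to [4, 6, 2], i.e. roughly 20% slower speech.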
||||
self.speaker_embedding_dim = speaker_embedding_dim
|
||||
# pass all config fields to `self`
|
||||
# for fewer code changes
|
||||
self.config = config
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
_, self.config, self.num_chars = self.get_characters(config)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# if the model is multi-speaker and c_in_channels is 0, set it to 256
|
||||
if num_speakers > 1:
|
||||
if self.c_in_channels == 0 and not self.speaker_embedding_dim:
|
||||
self.c_in_channels = 0
|
||||
if self.num_speakers > 1:
|
||||
if self.d_vector_dim:
|
||||
self.c_in_channels = self.d_vector_dim
|
||||
elif self.c_in_channels == 0 and not self.d_vector_dim:
|
||||
# TODO: make this adjustable
|
||||
self.c_in_channels = 256
|
||||
elif self.speaker_embedding_dim:
|
||||
self.c_in_channels = self.speaker_embedding_dim
|
||||
|
||||
self.encoder = Encoder(
|
||||
num_chars,
|
||||
out_channels=out_channels,
|
||||
hidden_channels=hidden_channels_enc,
|
||||
hidden_channels_dp=hidden_channels_dp,
|
||||
encoder_type=encoder_type,
|
||||
encoder_params=encoder_params,
|
||||
mean_only=mean_only,
|
||||
use_prenet=use_encoder_prenet,
|
||||
dropout_p_dp=dropout_p_dp,
|
||||
self.num_chars,
|
||||
out_channels=self.out_channels,
|
||||
hidden_channels=self.hidden_channels_enc,
|
||||
hidden_channels_dp=self.hidden_channels_dp,
|
||||
encoder_type=self.encoder_type,
|
||||
encoder_params=self.encoder_params,
|
||||
mean_only=self.mean_only,
|
||||
use_prenet=self.use_encoder_prenet,
|
||||
dropout_p_dp=self.dropout_p_dp,
|
||||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
self.decoder = Decoder(
|
||||
out_channels,
|
||||
hidden_channels_dec,
|
||||
kernel_size_dec,
|
||||
dilation_rate,
|
||||
num_flow_blocks_dec,
|
||||
num_block_layers,
|
||||
dropout_p=dropout_p_dec,
|
||||
num_splits=num_splits,
|
||||
num_squeeze=num_squeeze,
|
||||
sigmoid_scale=sigmoid_scale,
|
||||
self.out_channels,
|
||||
self.hidden_channels_dec,
|
||||
self.kernel_size_dec,
|
||||
self.dilation_rate,
|
||||
self.num_flow_blocks_dec,
|
||||
self.num_block_layers,
|
||||
dropout_p=self.dropout_p_dec,
|
||||
num_splits=self.num_splits,
|
||||
num_squeeze=self.num_squeeze,
|
||||
sigmoid_scale=self.sigmoid_scale,
|
||||
c_in_channels=self.c_in_channels,
|
||||
)
|
||||
|
||||
if num_speakers > 1 and not speaker_embedding_dim:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, self.c_in_channels)
|
||||
def init_multispeaker(self, config: "Coqpit", data: list = None) -> None:
|
||||
"""Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
|
||||
or with external `d_vectors` computed from a speaker encoder model.
|
||||
|
||||
If you need a different behaviour, override this function for your model.
|
||||
|
||||
Args:
|
||||
config (Coqpit): Model configuration.
|
||||
data (List, optional): Dataset items to infer number of speakers. Defaults to None.
|
||||
"""
|
||||
# init speaker manager
|
||||
self.speaker_manager = get_speaker_manager(config, data=data)
|
||||
self.num_speakers = self.speaker_manager.num_speakers
|
||||
# init speaker embedding layer
|
||||
if config.use_speaker_embedding and not config.use_d_vector_file:
|
||||
self.embedded_speaker_dim = self.c_in_channels
|
||||
self.emb_g = nn.Embedding(self.num_speakers, self.embedded_speaker_dim)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
@staticmethod
|
||||
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
|
||||
# compute final values with the computed alignment
|
||||
""" Compute and format the mode outputs with the given alignment map"""
|
||||
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
|
||||
1, 2
|
||||
) # [b, t', t], [b, t, d] -> [b, d, t']
|
||||
|
@ -140,19 +124,23 @@ class GlowTTS(nn.Module):
|
|||
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
|
||||
return y_mean, y_log_scale, o_attn_dur
|
||||
|
||||
def forward(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||
def forward(
|
||||
self, x, x_lengths, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T]
|
||||
x_lengths: B
|
||||
y: [B, C, T]
|
||||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
- x: :math:`[B, T]`
|
||||
- x_lengths: :math:`B`
|
||||
- y: :math:`[B, T, C]`
|
||||
- y_lengths: :math:`B`
|
||||
- g: :math:`[B, C] or B`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.speaker_embedding_dim:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -177,24 +165,38 @@ class GlowTTS(nn.Module):
|
|||
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
|
||||
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
||||
outputs = {
|
||||
"model_outputs": z.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference_with_MAS(self, x, x_lengths, y=None, y_lengths=None, attn=None, g=None):
|
||||
def inference_with_MAS(
|
||||
self, x, x_lengths, y=None, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
It is similar to teacher forcing in Tacotron.
|
||||
It was proposed in: https://arxiv.org/abs/2104.05557
|
||||
|
||||
Shapes:
|
||||
x: [B, T]
|
||||
x_lengths: B
|
||||
y: [B, C, T]
|
||||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
- x: :math:`[B, T]`
|
||||
- x_lengths: :math:`B`
|
||||
- y: :math:`[B, T, C]`
|
||||
- y_lengths: :math:`B`
|
||||
- g: :math:`[B, C] or B`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
# norm speaker embeddings
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
if g is not None:
|
||||
if self.external_speaker_embedding_dim:
|
||||
if self.external_d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -225,21 +227,33 @@ class GlowTTS(nn.Module):
|
|||
|
||||
# reverse the decoder and predict using the aligned distribution
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
|
||||
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
||||
outputs = {
|
||||
"model_outputs": z.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def decoder_inference(self, y, y_lengths=None, g=None):
|
||||
def decoder_inference(
|
||||
self, y, y_lengths=None, aux_input={"d_vectors": None}
|
||||
): # pylint: disable=dangerous-default-value
|
||||
"""
|
||||
Shapes:
|
||||
y: [B, C, T]
|
||||
y_lengths: B
|
||||
g: [B, C] or B
|
||||
- y: :math:`[B, T, C]`
|
||||
- y_lengths: :math:`B`
|
||||
- g: :math:`[B, C] or B`
|
||||
"""
|
||||
y = y.transpose(1, 2)
|
||||
y_max_length = y.size(2)
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
# norm speaker embeddings
|
||||
if g is not None:
|
||||
if self.external_speaker_embedding_dim:
|
||||
if self.external_d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
|
||||
|
@ -252,12 +266,18 @@ class GlowTTS(nn.Module):
|
|||
# reverse decoder and predict
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
|
||||
return y, logdet
|
||||
outputs = {}
|
||||
outputs["model_outputs"] = y.transpose(1, 2)
|
||||
outputs["logdet"] = logdet
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, x, x_lengths, g=None):
|
||||
def inference(self, x, aux_input={"x_lengths": None, "d_vectors": None}): # pylint: disable=dangerous-default-value
|
||||
x_lengths = aux_input["x_lengths"]
|
||||
g = aux_input["d_vectors"] if aux_input is not None and "d_vectors" in aux_input else None
|
||||
|
||||
if g is not None:
|
||||
if self.speaker_embedding_dim:
|
||||
if self.d_vector_dim:
|
||||
g = F.normalize(g).unsqueeze(-1)
|
||||
else:
|
||||
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
|
||||
|
@ -280,7 +300,72 @@ class GlowTTS(nn.Module):
|
|||
# decoder pass
|
||||
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
|
||||
attn = attn.squeeze(1).permute(0, 2, 1)
|
||||
return y, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
|
||||
outputs = {
|
||||
"model_outputs": y.transpose(1, 2),
|
||||
"logdet": logdet,
|
||||
"y_mean": y_mean.transpose(1, 2),
|
||||
"y_log_scale": y_log_scale.transpose(1, 2),
|
||||
"alignments": attn,
|
||||
"durations_log": o_dur_log.transpose(1, 2),
|
||||
"total_durations_log": o_attn_dur.transpose(1, 2),
|
||||
}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module):
|
||||
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||
|
||||
Args:
|
||||
batch (dict): Input batch as formatted by `format_batch()`.
|
||||
criterion (nn.Module): Loss layer used to compute the training losses.
|
||||
"""
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input={"d_vectors": d_vectors})
|
||||
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
outputs["y_mean"],
|
||||
outputs["y_log_scale"],
|
||||
outputs["logdet"],
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
outputs["total_durations_log"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower, the better)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
def preprocess(self, y, y_lengths, y_max_length, attn=None):
|
||||
if y_max_length is not None:
|
||||
|
@ -303,3 +388,8 @@ class GlowTTS(nn.Module):
|
|||
self.eval()
|
||||
self.store_inverse()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import GlowTTSLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return GlowTTSLoss()
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.feed_forward.decoder import Decoder
|
||||
|
@ -6,21 +9,16 @@ from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
|
|||
from TTS.tts.layers.feed_forward.encoder import Encoder
|
||||
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
|
||||
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
from TTS.tts.models.base_tts import BaseTTS
|
||||
from TTS.tts.utils.data import sequence_mask
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class SpeedySpeech(nn.Module):
|
||||
"""Speedy Speech model
|
||||
https://arxiv.org/abs/2008.03802
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
This model is able to achieve a reasonable performance with only
|
||||
~3M model parameters and convolutional layers.
|
||||
|
||||
This model requires precomputed phoneme durations to train a duration predictor. At inference
|
||||
it only uses the duration predictor to compute durations and expand the encoder outputs accordingly.
|
||||
|
||||
@dataclass
|
||||
class SpeedySpeechArgs(Coqpit):
|
||||
"""
|
||||
Args:
|
||||
num_chars (int): number of unique input characters.
|
||||
out_channels (int): number of output tensor channels. It is equal to the expected spectrogram size.
|
||||
|
@ -32,49 +30,106 @@ class SpeedySpeech(nn.Module):
|
|||
decoder_type (str, optional): decoder type. Defaults to 'residual_conv_bn'.
|
||||
decoder_params (dict, optional): set decoder parameters depending on 'decoder_type'. Defaults to { "kernel_size": 4, "dilations": 4 * [1, 2, 4, 8] + [1], "num_conv_blocks": 2, "num_res_blocks": 17 }.
|
||||
num_speakers (int, optional): number of speakers for multi-speaker training. Defaults to 0.
|
||||
external_c (bool, optional): enable external speaker embeddings. Defaults to False.
|
||||
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
|
||||
use_d_vector (bool, optional): enable external speaker embeddings. Defaults to False.
|
||||
d_vector_dim (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
|
||||
"""
|
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
out_channels,
|
||||
hidden_channels,
|
||||
positional_encoding=True,
|
||||
length_scale=1,
|
||||
encoder_type="residual_conv_bn",
|
||||
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
|
||||
decoder_type="residual_conv_bn",
|
||||
decoder_params={
|
||||
num_chars: int = None
|
||||
out_channels: int = 80
|
||||
hidden_channels: int = 128
|
||||
num_speakers: int = 0
|
||||
positional_encoding: bool = True
|
||||
length_scale: int = 1
|
||||
encoder_type: str = "residual_conv_bn"
|
||||
encoder_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 13,
|
||||
}
|
||||
)
|
||||
decoder_type: str = "residual_conv_bn"
|
||||
decoder_params: dict = field(
|
||||
default_factory=lambda: {
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 17,
|
||||
},
|
||||
num_speakers=0,
|
||||
external_c=False,
|
||||
c_in_channels=0,
|
||||
):
|
||||
}
|
||||
)
|
||||
use_d_vector: bool = False
|
||||
d_vector_dim: int = 0
|
||||
|
||||
|
||||
class SpeedySpeech(BaseTTS):
|
||||
"""Speedy Speech model
|
||||
https://arxiv.org/abs/2008.03802
|
||||
|
||||
Encoder -> DurationPredictor -> Decoder
|
||||
|
||||
Paper abstract:
|
||||
While recent neural sequence-to-sequence models have greatly improved the quality of speech
|
||||
synthesis, there has not been a system capable of fast training, fast inference and high-quality audio synthesis
|
||||
at the same time. We propose a student-teacher network capable of high-quality faster-than-real-time spectrogram
|
||||
synthesis, with low requirements on computational resources and fast training time. We show that self-attention
|
||||
layers are not necessary for generation of high quality audio. We utilize simple convolutional blocks with
|
||||
residual connections in both student and teacher networks and use only a single attention layer in the teacher
|
||||
model. Coupled with a MelGAN vocoder, our model's voice quality was rated significantly higher than Tacotron 2.
|
||||
Our model can be efficiently trained on a single GPU and can run in real time even on a CPU. We provide both
|
||||
our source code and audio samples in our GitHub repository.
|
||||
|
||||
Notes:
|
||||
The vanilla model is able to achieve a reasonable performance with only
|
||||
~3M model parameters and convolutional layers.
|
||||
|
||||
This model requires precomputed phoneme durations to train a duration predictor. At inference
|
||||
it only uses the duration predictor to compute durations and expand the encoder outputs accordingly.
|
||||
|
||||
You can also mix and match encoder and decoder networks beyond the ones used in the paper.
|
||||
|
||||
Check `SpeedySpeechArgs` for arguments.
|
||||
"""
|
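Instantiation sketch (the module path is assumed from this diff and the vocabulary size is hypothetical):

from TTS.tts.models.speedy_speech import SpeedySpeechArgs

args = SpeedySpeechArgs(num_chars=129)            # hypothetical symbol-set size
print(args.encoder_type)                          # residual_conv_bn
print(args.decoder_params["num_res_blocks"])      # 17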
||||
|
||||
# pylint: disable=dangerous-default-value
|
||||
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__()
|
||||
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
|
||||
self.emb = nn.Embedding(num_chars, hidden_channels)
|
||||
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
|
||||
if positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(hidden_channels)
|
||||
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
|
||||
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
|
||||
self.config = config
|
||||
|
||||
if num_speakers > 1 and not external_c:
|
||||
if "characters" in config:
|
||||
_, self.config, self.num_chars = self.get_characters(config)
|
||||
|
||||
self.length_scale = (
|
||||
float(config.model_args.length_scale)
|
||||
if isinstance(config.model_args.length_scale, int)
|
||||
else config.model_args.length_scale
|
||||
)
|
||||
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
|
||||
self.encoder = Encoder(
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.encoder_type,
|
||||
config.model_args.encoder_params,
|
||||
config.model_args.d_vector_dim,
|
||||
)
|
||||
if config.model_args.positional_encoding:
|
||||
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
|
||||
self.decoder = Decoder(
|
||||
config.model_args.out_channels,
|
||||
config.model_args.hidden_channels,
|
||||
config.model_args.decoder_type,
|
||||
config.model_args.decoder_params,
|
||||
)
|
||||
self.duration_predictor = DurationPredictor(config.model_args.hidden_channels + config.model_args.d_vector_dim)
|
||||
|
||||
if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
|
||||
# speaker embedding layer
|
||||
self.emb_g = nn.Embedding(num_speakers, c_in_channels)
|
||||
self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
|
||||
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
|
||||
|
||||
if c_in_channels > 0 and c_in_channels != hidden_channels:
|
||||
self.proj_g = nn.Conv1d(c_in_channels, hidden_channels, 1)
|
||||
if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
|
||||
self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
|
||||
|
||||
@staticmethod
|
||||
def expand_encoder_outputs(en, dr, x_mask, y_mask):
|
||||
|
@ -153,8 +208,11 @@ class SpeedySpeech(nn.Module):
|
|||
o_de = self.decoder(o_en_ex, y_mask, g=g)
|
||||
return o_de, attn.transpose(1, 2)
|
||||
|
||||
def forward(self, x, x_lengths, y_lengths, dr, g=None): # pylint: disable=unused-argument
|
||||
def forward(
|
||||
self, x, x_lengths, y_lengths, dr, aux_input={"d_vectors": None, "speaker_ids": None}
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
TODO: speaker embedding for speaker_ids
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
|
@ -162,18 +220,22 @@ class SpeedySpeech(nn.Module):
|
|||
dr: [B, T_max]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
|
||||
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
|
||||
return o_de, o_dr_log.squeeze(1), attn
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "durations_log": o_dr_log.squeeze(1), "alignments": attn}
|
||||
return outputs
|
||||
|
||||
def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
|
||||
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
|
||||
"""
|
||||
Shapes:
|
||||
x: [B, T_max]
|
||||
x_lengths: [B]
|
||||
g: [B, C]
|
||||
"""
|
||||
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
|
||||
# the input sequence should be longer than the max convolution size
|
||||
inference_padding = 5
|
||||
if x.shape[1] < 13:
|
||||
|
@ -186,7 +248,60 @@ class SpeedySpeech(nn.Module):
|
|||
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
|
||||
y_lengths = o_dr.sum(1)
|
||||
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
|
||||
return o_de, attn
|
||||
outputs = {"model_outputs": o_de.transpose(1, 2), "alignments": attn, "durations_log": None}
|
||||
return outputs
|
||||
|
||||
def train_step(self, batch: dict, criterion: nn.Module):
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
durations = batch["durations"]
|
||||
|
||||
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
|
||||
outputs = self.forward(text_input, text_lengths, mel_lengths, durations, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
outputs["durations_log"],
|
||||
torch.log(1 + durations),
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower, the better)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"], binary=True)
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
|
||||
model_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = model_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
}
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch: dict, criterion: nn.Module):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
|
@ -196,3 +311,8 @@ class SpeedySpeech(nn.Module):
|
|||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
def get_criterion(self):
|
||||
from TTS.tts.layers.losses import SpeedySpeechLoss # pylint: disable=import-outside-toplevel
|
||||
|
||||
return SpeedySpeechLoss(self.config)
|
||||
|
|
|
@ -1,157 +1,86 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.tacotron.tacotron import Decoder, Encoder, PostCBHG
|
||||
from TTS.tts.models.tacotron_abstract import TacotronAbstract
|
||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
class Tacotron(TacotronAbstract):
|
||||
class Tacotron(BaseTacotron):
|
||||
"""Tacotron as in https://arxiv.org/abs/1703.10135
|
||||
|
||||
It's an autoregressive encoder-attention-decoder-postnet architecture.
|
||||
|
||||
Args:
|
||||
num_chars (int): number of input characters to define the size of embedding layer.
|
||||
num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and model learns speaker embeddings.
|
||||
r (int): initial model reduction rate.
|
||||
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
|
||||
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
|
||||
attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
|
||||
attn_win (bool, optional): enable/disable attention windowing.
|
||||
It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
|
||||
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
|
||||
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
|
||||
prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
|
||||
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
|
||||
some models. Defaults to False.
|
||||
forward_attn (bool, optional): enable/disable forward attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to False.
|
||||
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
|
||||
forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False.
|
||||
location_attn (bool, optional): enable/disable location sensitive attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to True.
|
||||
attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5.
|
||||
separate_stopnet (bool, optional): enable/disable separate stopnet training; when enabled, no gradient
|
||||
flows from the stopnet to the rest of the model. Defaults to True.
|
||||
bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False.
|
||||
double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False.
|
||||
ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None.
|
||||
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
|
||||
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
|
||||
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||
use_gst (bool, optional): enable/disable Global style token module.
|
||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
memory_size (int, optional): size of the history queue fed to the prenet. The model feeds the last ```memory_size```
|
||||
output frames to the prenet.
|
||||
Check `TacotronConfig` for the arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r=5,
|
||||
postnet_output_dim=1025,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="sigmoid",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
prenet_dropout_at_inference=False,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
encoder_in_features=256,
|
||||
decoder_in_features=256,
|
||||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
memory_size=5,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim,
|
||||
decoder_output_dim,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
prenet_dropout_at_inference,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
bidirectional_decoder,
|
||||
double_decoder_consistency,
|
||||
ddc_r,
|
||||
encoder_in_features,
|
||||
decoder_in_features,
|
||||
speaker_embedding_dim,
|
||||
use_gst,
|
||||
gst,
|
||||
)
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
|
||||
# speaker embedding layers
|
||||
self.num_chars, self.config = self.get_characters(config)
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# speaker embedding layer
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embedding_dim = 256
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# speaker and gst embeddings are concatenated with the decoder input
|
||||
if self.num_speakers > 1:
|
||||
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
|
||||
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
|
||||
|
||||
if self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim
|
||||
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
|
||||
self.embedding = nn.Embedding(self.num_chars, 256, padding_idx=0)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
|
||||
# base model layers
|
||||
self.encoder = Encoder(self.encoder_in_features)
|
||||
self.decoder = Decoder(
|
||||
self.decoder_in_features,
|
||||
decoder_output_dim,
|
||||
r,
|
||||
memory_size,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
self.decoder_output_dim,
|
||||
self.r,
|
||||
self.memory_size,
|
||||
self.attention_type,
|
||||
self.windowing,
|
||||
self.attention_norm,
|
||||
self.prenet_type,
|
||||
self.prenet_dropout,
|
||||
self.use_forward_attn,
|
||||
self.transition_agent,
|
||||
self.forward_attn_mask,
|
||||
self.location_attn,
|
||||
self.attention_heads,
|
||||
self.separate_stopnet,
|
||||
self.max_decoder_steps,
|
||||
)
|
||||
self.postnet = PostCBHG(decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
|
||||
self.postnet = PostCBHG(self.decoder_output_dim)
|
||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, self.out_channels)
|
||||
|
||||
# setup prenet dropout
|
||||
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
|
||||
self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference
|
||||
|
||||
# global style token layers
|
||||
if self.gst and self.use_gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=decoder_output_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
num_heads=gst.gst_num_heads,
|
||||
num_style_tokens=gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=gst.gst_embedding_dim,
|
||||
num_mel=self.decoder_output_dim,
|
||||
d_vector_dim=self.d_vector_dim
|
||||
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
|
||||
else None,
|
||||
num_heads=self.gst.gst_num_heads,
|
||||
num_style_tokens=self.gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
)
|
||||
# backward pass decoder
|
||||
if self.bidirectional_decoder:
|
||||
|
@ -160,35 +89,35 @@ class Tacotron(TacotronAbstract):
|
|||
if self.double_decoder_consistency:
|
||||
self.coarse_decoder = Decoder(
|
||||
self.decoder_in_features,
|
||||
decoder_output_dim,
|
||||
ddc_r,
|
||||
memory_size,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
self.decoder_output_dim,
|
||||
self.ddc_r,
|
||||
self.memory_size,
|
||||
self.attention_type,
|
||||
self.windowing,
|
||||
self.attention_norm,
|
||||
self.prenet_type,
|
||||
self.prenet_dropout,
|
||||
self.use_forward_attn,
|
||||
self.transition_agent,
|
||||
self.forward_attn_mask,
|
||||
self.location_attn,
|
||||
self.attention_heads,
|
||||
self.separate_stopnet,
|
||||
self.max_decoder_steps,
|
||||
)
|
||||
|
||||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
characters: [B, T_in]
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
speaker_ids: [B, 1]
|
||||
speaker_embeddings: [B, C]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors': [B, C]
|
||||
"""
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
inputs = self.embedding(text)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
# B x T_in x embed_dim
|
||||
inputs = self.embedding(characters)
|
||||
# B x T_in x encoder_in_features
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
# sequence masking
|
||||
|
@ -196,16 +125,18 @@ class Tacotron(TacotronAbstract):
|
|||
# global style token
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
)
|
||||
# speaker embedding
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
# decoder_outputs: B x decoder_in_features x T_out
|
||||
# alignments: B x T_in x encoder_in_features
|
||||
# stop_tokens: B x T_in
|
||||
|
@ -224,45 +155,139 @@ class Tacotron(TacotronAbstract):
|
|||
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return (
|
||||
decoder_outputs,
|
||||
postnet_outputs,
|
||||
alignments,
|
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||
mel_specs, encoder_outputs, alignments, input_mask
|
||||
)
|
||||
return (
|
||||
decoder_outputs,
|
||||
postnet_outputs,
|
||||
alignments,
|
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
||||
inputs = self.embedding(characters)
|
||||
def inference(self, text_input, aux_input=None):
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
inputs = self.embedding(text_input)
|
||||
encoder_outputs = self.encoder(inputs)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])
|
||||
# reshape embedded_speakers
|
||||
if embedded_speakers.ndim == 1:
|
||||
embedded_speakers = embedded_speakers[None, None, :]
|
||||
elif embedded_speakers.ndim == 2:
|
||||
embedded_speakers = embedded_speakers[None, :]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = self.last_linear(postnet_outputs)
|
||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
outputs = {
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
return outputs
|
||||
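The `ndim` checks above normalize a raw speaker embedding to `[B, 1, C]` before it is broadcast over time. A self-contained illustration with a single 64-dim vector:

import torch

emb = torch.zeros(64)         # a bare speaker vector, ndim == 1
if emb.ndim == 1:
    emb = emb[None, None, :]  # -> [1, 1, 64]
elif emb.ndim == 2:
    emb = emb[None, :]        # [1, C] -> [1, 1, C]
print(emb.shape)              # torch.Size([1, 1, 64])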
|
||||
def train_step(self, batch, criterion):
|
||||
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||
|
||||
Args:
|
||||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
"""
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
# forward pass model
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
|
||||
# set the alignment lengths w.r.t. the reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
alignment_lengths = (
|
||||
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
|
||||
) // self.decoder.r
|
||||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
outputs["decoder_outputs"],
|
||||
mel_input,
|
||||
linear_input,
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
mel_lengths,
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
alignment_lengths,
|
||||
outputs["alignments_backward"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower, the better)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
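The alignment lengths above are derived by padding the batch-max mel length up to the next multiple of the reduction factor `r` before dividing, so the last partial frame group is still counted. A self-contained check:

import torch

r = 3
mel_lengths = torch.tensor([7, 10])
if mel_lengths.max() % r != 0:
    alignment_lengths = (mel_lengths + (r - (mel_lengths.max() % r))) // r
else:
    alignment_lengths = mel_lengths // r
print(alignment_lengths)  # tensor([3, 4])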
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
alignments_backward = outputs["alignments_backward"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
}
|
||||
|
||||
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_spectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch, criterion):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap, batch, outputs):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
|
|
@ -1,151 +1,84 @@
|
|||
# coding: utf-8
|
||||
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from coqpit import Coqpit
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
||||
from TTS.tts.models.tacotron_abstract import TacotronAbstract
|
||||
from TTS.tts.models.base_tacotron import BaseTacotron
|
||||
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
|
||||
|
||||
# TODO: match function arguments with tacotron
|
||||
class Tacotron2(TacotronAbstract):
|
||||
class Tacotron2(BaseTacotron):
|
||||
"""Tacotron2 as in https://arxiv.org/abs/1712.05884
|
||||
|
||||
It's an autoregressive encoder-attention-decoder-postnet architecture.
|
||||
|
||||
Args:
|
||||
num_chars (int): number of input characters to define the size of embedding layer.
|
||||
num_speakers (int): number of speakers in the dataset. >1 enables multi-speaker training and model learns speaker embeddings.
|
||||
r (int): initial model reduction rate.
|
||||
postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
|
||||
decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
|
||||
attn_type (str, optional): attention type. Check ```TTS.tts.layers.tacotron.common_layers.init_attn```. Defaults to 'original'.
|
||||
attn_win (bool, optional): enable/disable attention windowing.
|
||||
It is especially useful at inference to keep the attention alignment diagonal. Defaults to False.
|
||||
attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax". Defaults to "softmax".
|
||||
prenet_type (str, optional): prenet type for the decoder. Defaults to "original".
|
||||
prenet_dropout (bool, optional): prenet dropout rate. Defaults to True.
|
||||
prenet_dropout_at_inference (bool, optional): use dropout at inference time. This leads to a better quality for
|
||||
some models. Defaults to False.
|
||||
forward_attn (bool, optional): enable/disable forward attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to False.
|
||||
trans_agent (bool, optional): enable/disable transition agent in forward attention. Defaults to False.
|
||||
forward_attn_mask (bool, optional): enable/disable extra masking over forward attention. Defaults to False.
|
||||
location_attn (bool, optional): enable/disable location sensitive attention.
|
||||
It is only valid if ```attn_type``` is ```original```. Defaults to True.
|
||||
attn_K (int, optional): Number of attention heads for GMM attention. Defaults to 5.
|
||||
separate_stopnet (bool, optional): enable/disable separate stopnet training; when enabled, no gradient
|
||||
flows from the stopnet to the rest of the model. Defaults to True.
|
||||
bidirectional_decoder (bool, optional): enable/disable bidirectional decoding. Defaults to False.
|
||||
double_decoder_consistency (bool, optional): enable/disable double decoder consistency. Defaults to False.
|
||||
ddc_r (int, optional): reduction rate for the coarse decoder of double decoder consistency. Defaults to None.
|
||||
encoder_in_features (int, optional): input channels for the encoder. Defaults to 512.
|
||||
decoder_in_features (int, optional): input channels for the decoder. Defaults to 512.
|
||||
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||
use_gst (bool, optional): enable/disable Global style token module.
|
||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||
Check `TacotronConfig` for the arguments.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
prenet_dropout_at_inference=False,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
):
|
||||
super().__init__(
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim,
|
||||
decoder_output_dim,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
prenet_dropout_at_inference,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
bidirectional_decoder,
|
||||
double_decoder_consistency,
|
||||
ddc_r,
|
||||
encoder_in_features,
|
||||
decoder_in_features,
|
||||
speaker_embedding_dim,
|
||||
use_gst,
|
||||
gst,
|
||||
)
|
||||
def __init__(self, config: Coqpit):
|
||||
super().__init__(config)
|
||||
|
||||
chars, self.config = self.get_characters(config)
|
||||
self.num_chars = len(chars)
|
||||
self.decoder_output_dim = config.out_channels
|
||||
|
||||
# pass all config fields to `self`
|
||||
# for fewer code change
|
||||
for key in config:
|
||||
setattr(self, key, config[key])
|
||||
|
||||
# speaker embedding layer
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embedding_dim = 512
|
||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.init_multispeaker(config)
|
||||
|
||||
# speaker and gst embeddings are concatenated with the decoder input
|
||||
if self.num_speakers > 1:
|
||||
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
|
||||
self.decoder_in_features += self.embedded_speaker_dim # add speaker embedding dim
|
||||
|
||||
if self.use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim
|
||||
|
||||
# embedding layer
|
||||
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
|
||||
self.embedding = nn.Embedding(self.num_chars, 512, padding_idx=0)
|
||||
|
||||
# base model layers
|
||||
self.encoder = Encoder(self.encoder_in_features)
|
||||
self.decoder = Decoder(
|
||||
self.decoder_in_features,
|
||||
self.decoder_output_dim,
|
||||
r,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
self.r,
|
||||
self.attention_type,
|
||||
self.attention_win,
|
||||
self.attention_norm,
|
||||
self.prenet_type,
|
||||
self.prenet_dropout,
|
||||
self.use_forward_attn,
|
||||
self.transition_agent,
|
||||
self.forward_attn_mask,
|
||||
self.location_attn,
|
||||
self.attention_heads,
|
||||
self.separate_stopnet,
|
||||
self.max_decoder_steps,
|
||||
)
|
||||
self.postnet = Postnet(self.postnet_output_dim)
|
||||
self.postnet = Postnet(self.out_channels)
|
||||
|
||||
# setup prenet dropout
|
||||
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
|
||||
self.decoder.prenet.dropout_at_inference = self.prenet_dropout_at_inference
|
||||
|
||||
# global style token layers
|
||||
if self.gst and use_gst:
|
||||
if self.gst and self.use_gst:
|
||||
self.gst_layer = GST(
|
||||
num_mel=decoder_output_dim,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
num_heads=gst.gst_num_heads,
|
||||
num_style_tokens=gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=gst.gst_embedding_dim,
|
||||
num_mel=self.decoder_output_dim,
|
||||
d_vector_dim=self.d_vector_dim
|
||||
if self.config.gst.gst_use_speaker_embedding and self.use_speaker_embedding
|
||||
else None,
|
||||
num_heads=self.gst.gst_num_heads,
|
||||
num_style_tokens=self.gst.gst_num_style_tokens,
|
||||
gst_embedding_dim=self.gst.gst_embedding_dim,
|
||||
)
|
||||
|
||||
# backward pass decoder
|
||||
|
@ -156,18 +89,19 @@ class Tacotron2(TacotronAbstract):
|
|||
self.coarse_decoder = Decoder(
|
||||
self.decoder_in_features,
|
||||
self.decoder_output_dim,
|
||||
ddc_r,
|
||||
attn_type,
|
||||
attn_win,
|
||||
attn_norm,
|
||||
prenet_type,
|
||||
prenet_dropout,
|
||||
forward_attn,
|
||||
trans_agent,
|
||||
forward_attn_mask,
|
||||
location_attn,
|
||||
attn_K,
|
||||
separate_stopnet,
|
||||
self.ddc_r,
|
||||
self.attention_type,
|
||||
self.attention_win,
|
||||
self.attention_norm,
|
||||
self.prenet_type,
|
||||
self.prenet_dropout,
|
||||
self.use_forward_attn,
|
||||
self.transition_agent,
|
||||
self.forward_attn_mask,
|
||||
self.location_attn,
|
||||
self.attention_heads,
|
||||
self.separate_stopnet,
|
||||
self.max_decoder_steps,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -176,16 +110,17 @@ class Tacotron2(TacotronAbstract):
|
|||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||
return mel_outputs, mel_outputs_postnet, alignments
|
||||
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, aux_input=None):
|
||||
"""
|
||||
Shapes:
|
||||
text: [B, T_in]
|
||||
text_lengths: [B]
|
||||
mel_specs: [B, T_out, C]
|
||||
mel_lengths: [B]
|
||||
speaker_ids: [B, 1]
|
||||
speaker_embeddings: [B, C]
|
||||
aux_input: 'speaker_ids': [B, 1] and 'd_vectors': [B, C]
|
||||
"""
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
outputs = {"alignments_backward": None, "decoder_outputs_backward": None}
|
||||
# compute mask for padding
|
||||
# B x T_in_max (boolean)
|
||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||
|
@ -195,15 +130,17 @@ class Tacotron2(TacotronAbstract):
|
|||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, mel_specs, aux_input["d_vectors"] if "d_vectors" in aux_input else None
|
||||
)
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
if not self.use_d_vectors:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[:, None]
|
||||
else:
|
||||
# B x 1 x speaker_embed_dim
|
||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
embedded_speakers = torch.unsqueeze(aux_input["d_vectors"], 1)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
|
||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
||||
|
||||
|
@ -222,67 +159,140 @@ class Tacotron2(TacotronAbstract):
|
|||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
||||
if self.bidirectional_decoder:
|
||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
||||
return (
|
||||
decoder_outputs,
|
||||
postnet_outputs,
|
||||
alignments,
|
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||
mel_specs, encoder_outputs, alignments, input_mask
|
||||
)
|
||||
return (
|
||||
decoder_outputs,
|
||||
postnet_outputs,
|
||||
alignments,
|
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
outputs["alignments_backward"] = alignments_backward
|
||||
outputs["decoder_outputs_backward"] = decoder_outputs_backward
|
||||
outputs.update(
|
||||
{
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
)
|
||||
return outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
||||
def inference(self, text, aux_input=None):
|
||||
aux_input = self._format_aux_input(aux_input)
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||
|
||||
if self.gst and self.use_gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, aux_input["style_mel"], aux_input["d_vectors"])
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
if not self.use_d_vectors:
|
||||
embedded_speakers = self.speaker_embedding(aux_input["speaker_ids"])[None]
|
||||
# reshape embedded_speakers
|
||||
if embedded_speakers.ndim == 1:
|
||||
embedded_speakers = embedded_speakers[None, None, :]
|
||||
elif embedded_speakers.ndim == 2:
|
||||
embedded_speakers = embedded_speakers[None, :]
|
||||
else:
|
||||
embedded_speakers = aux_input["d_vectors"]
|
||||
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, embedded_speakers)
|
||||
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
outputs = {
|
||||
"model_outputs": postnet_outputs,
|
||||
"decoder_outputs": decoder_outputs,
|
||||
"alignments": alignments,
|
||||
"stop_tokens": stop_tokens,
|
||||
}
|
||||
return outputs
|
||||
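`_format_aux_input` (defined in the base class, not shown in this diff) fills missing keys with `None` so the lookups above never raise. A hypothetical stand-in with the same behavior:

def format_aux_input(aux_input):
    # hypothetical sketch; the real helper lives in the base model class
    defaults = {"speaker_ids": None, "d_vectors": None, "style_mel": None}
    if aux_input is not None:
        defaults.update(aux_input)
    return defaults

print(format_aux_input(None))                # every key present, all None
print(format_aux_input({"speaker_ids": 3}))  # missing keys filled with defaults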
|
||||
def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
||||
def train_step(self, batch, criterion):
|
||||
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||
|
||||
Args:
|
||||
batch ([type]): [description]
|
||||
criterion ([type]): [description]
|
||||
"""
|
||||
Preserve model states for continuous inference
|
||||
"""
|
||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
|
||||
text_input = batch["text_input"]
|
||||
text_lengths = batch["text_lengths"]
|
||||
mel_input = batch["mel_input"]
|
||||
mel_lengths = batch["mel_lengths"]
|
||||
linear_input = batch["linear_input"]
|
||||
stop_targets = batch["stop_targets"]
|
||||
speaker_ids = batch["speaker_ids"]
|
||||
d_vectors = batch["d_vectors"]
|
||||
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
||||
# forward pass model
|
||||
outputs = self.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_input,
|
||||
mel_lengths,
|
||||
aux_input={"speaker_ids": speaker_ids, "d_vectors": d_vectors},
|
||||
)
|
||||
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2)
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
# set the alignment lengths w.r.t. the reduction factor for guided attention
|
||||
if mel_lengths.max() % self.decoder.r != 0:
|
||||
alignment_lengths = (
|
||||
mel_lengths + (self.decoder.r - (mel_lengths.max() % self.decoder.r))
|
||||
) // self.decoder.r
|
||||
else:
|
||||
alignment_lengths = mel_lengths // self.decoder.r
|
||||
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
aux_input = {"speaker_ids": speaker_ids, "d_vectors": d_vectors}
|
||||
outputs = self.forward(text_input, text_lengths, mel_input, mel_lengths, aux_input)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(
|
||||
outputs["model_outputs"],
|
||||
outputs["decoder_outputs"],
|
||||
mel_input,
|
||||
linear_input,
|
||||
outputs["stop_tokens"],
|
||||
stop_targets,
|
||||
mel_lengths,
|
||||
outputs["decoder_outputs_backward"],
|
||||
outputs["alignments"],
|
||||
alignment_lengths,
|
||||
outputs["alignments_backward"],
|
||||
text_lengths,
|
||||
)
|
||||
|
||||
# compute alignment error (the lower, the better)
|
||||
align_error = 1 - alignment_diagonal_score(outputs["alignments"])
|
||||
loss_dict["align_error"] = align_error
|
||||
return outputs, loss_dict
|
||||
|
||||
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict) -> Tuple[Dict, Dict]:
|
||||
postnet_outputs = outputs["model_outputs"]
|
||||
alignments = outputs["alignments"]
|
||||
alignments_backward = outputs["alignments_backward"]
|
||||
mel_input = batch["mel_input"]
|
||||
|
||||
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||
gt_spec = mel_input[0].data.cpu().numpy()
|
||||
align_img = alignments[0].data.cpu().numpy()
|
||||
|
||||
figures = {
|
||||
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||
"alignment": plot_alignment(align_img, output_fig=False),
|
||||
}
|
||||
|
||||
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||
|
||||
# Sample audio
|
||||
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||
return figures, {"audio": train_audio}
|
||||
|
||||
def eval_step(self, batch, criterion):
|
||||
return self.train_step(batch, criterion)
|
||||
|
||||
def eval_log(self, ap, batch, outputs):
|
||||
return self.train_log(ap, batch, outputs)
|
||||
|
|
|
@ -1,218 +0,0 @@
|
|||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from TTS.tts.utils.generic_utils import sequence_mask
|
||||
|
||||
|
||||
class TacotronAbstract(ABC, nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
attn_norm="softmax",
|
||||
prenet_type="original",
|
||||
prenet_dropout=True,
|
||||
prenet_dropout_at_inference=False,
|
||||
forward_attn=False,
|
||||
trans_agent=False,
|
||||
forward_attn_mask=False,
|
||||
location_attn=True,
|
||||
attn_K=5,
|
||||
separate_stopnet=True,
|
||||
bidirectional_decoder=False,
|
||||
double_decoder_consistency=False,
|
||||
ddc_r=None,
|
||||
encoder_in_features=512,
|
||||
decoder_in_features=512,
|
||||
speaker_embedding_dim=None,
|
||||
use_gst=False,
|
||||
gst=None,
|
||||
):
|
||||
"""Abstract Tacotron class"""
|
||||
super().__init__()
|
||||
self.num_chars = num_chars
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.use_gst = use_gst
|
||||
self.gst = gst
|
||||
self.num_speakers = num_speakers
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.double_decoder_consistency = double_decoder_consistency
|
||||
self.ddc_r = ddc_r
|
||||
self.attn_type = attn_type
|
||||
self.attn_win = attn_win
|
||||
self.attn_norm = attn_norm
|
||||
self.prenet_type = prenet_type
|
||||
self.prenet_dropout = prenet_dropout
|
||||
self.prenet_dropout_at_inference = prenet_dropout_at_inference
|
||||
self.forward_attn = forward_attn
|
||||
self.trans_agent = trans_agent
|
||||
self.forward_attn_mask = forward_attn_mask
|
||||
self.location_attn = location_attn
|
||||
self.attn_K = attn_K
|
||||
self.separate_stopnet = separate_stopnet
|
||||
self.encoder_in_features = encoder_in_features
|
||||
self.decoder_in_features = decoder_in_features
|
||||
self.speaker_embedding_dim = speaker_embedding_dim
|
||||
|
||||
# layers
|
||||
self.embedding = None
|
||||
self.encoder = None
|
||||
self.decoder = None
|
||||
self.postnet = None
|
||||
|
||||
# multispeaker
|
||||
if self.speaker_embedding_dim is None:
|
||||
# if speaker_embedding_dim is None, use nn.Embedding with the default speaker_embedding_dim
|
||||
self.embeddings_per_sample = False
|
||||
else:
|
||||
# if speaker_embedding_dim is not None, use an external speaker embedding per sample
|
||||
self.embeddings_per_sample = True
|
||||
|
||||
# global style token
|
||||
if self.gst and use_gst:
|
||||
self.decoder_in_features += self.gst.gst_embedding_dim # add gst embedding dim
|
||||
self.gst_layer = None
|
||||
|
||||
# model states
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
# additional layers
|
||||
self.decoder_backward = None
|
||||
self.coarse_decoder = None
|
||||
|
||||
#############################
|
||||
# INIT FUNCTIONS
|
||||
#############################
|
||||
|
||||
def _init_states(self):
|
||||
self.speaker_embeddings = None
|
||||
self.speaker_embeddings_projected = None
|
||||
|
||||
def _init_backward_decoder(self):
|
||||
self.decoder_backward = copy.deepcopy(self.decoder)
|
||||
|
||||
def _init_coarse_decoder(self):
|
||||
self.coarse_decoder = copy.deepcopy(self.decoder)
|
||||
self.coarse_decoder.r_init = self.ddc_r
|
||||
self.coarse_decoder.set_r(self.ddc_r)
|
||||
|
||||
#############################
|
||||
# CORE FUNCTIONS
|
||||
#############################
|
||||
|
||||
@abstractmethod
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def inference(self):
|
||||
pass
|
||||
|
||||
def load_checkpoint(
|
||||
self, config, checkpoint_path, eval=False
|
||||
): # pylint: disable=unused-argument, redefined-builtin
|
||||
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
|
||||
self.load_state_dict(state["model"])
|
||||
self.decoder.set_r(state["r"])
|
||||
if eval:
|
||||
self.eval()
|
||||
assert not self.training
|
||||
|
||||
#############################
|
||||
# COMMON COMPUTE FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_masks(self, text_lengths, mel_lengths):
|
||||
"""Compute masks against sequence paddings."""
|
||||
# B x T_in_max (boolean)
|
||||
device = text_lengths.device
|
||||
input_mask = sequence_mask(text_lengths).to(device)
|
||||
output_mask = None
|
||||
if mel_lengths is not None:
|
||||
max_len = mel_lengths.max()
|
||||
r = self.decoder.r
|
||||
max_len = max_len + (r - (max_len % r)) if max_len % r > 0 else max_len
|
||||
output_mask = sequence_mask(mel_lengths, max_len=max_len).to(device)
|
||||
return input_mask, output_mask
|
||||
|
||||
def _backward_pass(self, mel_specs, encoder_outputs, mask):
|
||||
"""Run backwards decoder"""
|
||||
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
|
||||
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
|
||||
)
|
||||
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
|
||||
return decoder_outputs_b, alignments_b
|
||||
|
||||
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
|
||||
"""Double Decoder Consistency"""
|
||||
T = mel_specs.shape[1]
|
||||
if T % self.coarse_decoder.r > 0:
|
||||
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
|
||||
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
|
||||
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
|
||||
encoder_outputs.detach(), mel_specs, input_mask
|
||||
)
|
||||
# scale_factor = self.decoder.r_init / self.decoder.r
|
||||
alignments_backward = torch.nn.functional.interpolate(
|
||||
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
|
||||
).transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
|
||||
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
|
||||
return decoder_outputs_backward, alignments_backward
|
||||
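The coarse decoder runs at a larger reduction rate, so its alignments are shorter in time; the nearest-neighbor interpolation above stretches them back to the fine decoder's length before the consistency loss. Illustrative shapes:

import torch

alignments_backward = torch.rand(2, 10, 50)  # [B, T_coarse, T_in]
fine_len = 30                                # fine decoder's output length
upsampled = torch.nn.functional.interpolate(
    alignments_backward.transpose(1, 2), size=fine_len, mode="nearest"
).transpose(1, 2)
print(upsampled.shape)                       # torch.Size([2, 30, 50])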
|
||||
#############################
|
||||
# EMBEDDING FUNCTIONS
|
||||
#############################
|
||||
|
||||
def compute_speaker_embedding(self, speaker_ids):
|
||||
"""Compute speaker embedding vectors"""
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is None:
|
||||
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
|
||||
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
|
||||
self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
|
||||
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
|
||||
self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1)
|
||||
|
||||
def compute_gst(self, inputs, style_input, speaker_embedding=None):
|
||||
"""Compute global style token"""
|
||||
device = inputs.device
|
||||
if isinstance(style_input, dict):
|
||||
query = torch.zeros(1, 1, self.gst.gst_embedding_dim // 2).to(device)
|
||||
if speaker_embedding is not None:
|
||||
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
|
||||
|
||||
_GST = torch.tanh(self.gst_layer.style_token_layer.style_tokens)
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
|
||||
for k_token, v_amplifier in style_input.items():
|
||||
key = _GST[int(k_token)].unsqueeze(0).expand(1, -1, -1)
|
||||
gst_outputs_att = self.gst_layer.style_token_layer.attention(query, key)
|
||||
gst_outputs = gst_outputs + gst_outputs_att * v_amplifier
|
||||
elif style_input is None:
|
||||
gst_outputs = torch.zeros(1, 1, self.gst.gst_embedding_dim).to(device)
|
||||
else:
|
||||
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
|
||||
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
|
||||
return inputs
|
||||
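The dict branch of `compute_gst` lets a caller mix style tokens directly: keys are token indices and values are amplifiers for each token's attention output. A self-contained sketch of the accumulation, with stand-in tensors replacing the real attention outputs:

import torch

gst_embedding_dim = 8
style_input = {"0": 0.3, "2": -0.1}  # token index -> amplifier
# stand-ins for self.gst_layer.style_token_layer.attention(query, key)
gst_outputs_att = {k: torch.ones(1, 1, gst_embedding_dim) for k in style_input}

gst_outputs = torch.zeros(1, 1, gst_embedding_dim)
for k_token, v_amplifier in style_input.items():
    gst_outputs = gst_outputs + gst_outputs_att[k_token] * v_amplifier
print(gst_outputs[0, 0, 0])  # tensor(0.2000)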
|
||||
@staticmethod
|
||||
def _add_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = outputs + speaker_embeddings_
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _concat_speaker_embedding(outputs, speaker_embeddings):
|
||||
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
|
||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||
return outputs
|
|
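`_concat_speaker_embedding` broadcasts one conditioning vector over every timestep and concatenates on the channel axis, which is why `decoder_in_features` grows by the embedding size in the constructors above. Shape illustration:

import torch

encoder_outputs = torch.zeros(2, 7, 256)  # [B, T, C]
speaker_emb = torch.zeros(2, 1, 64)       # [B, 1, C_spk]
expanded = speaker_emb.expand(2, 7, -1)   # broadcast along T without copying
out = torch.cat([encoder_outputs, expanded], dim=-1)
print(out.shape)                          # torch.Size([2, 7, 320])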
@ -12,7 +12,7 @@ class Tacotron2(keras.models.Model):
|
|||
num_chars,
|
||||
num_speakers,
|
||||
r,
|
||||
postnet_output_dim=80,
|
||||
out_channels=80,
|
||||
decoder_output_dim=80,
|
||||
attn_type="original",
|
||||
attn_win=False,
|
||||
|
@ -31,7 +31,7 @@ class Tacotron2(keras.models.Model):
|
|||
super().__init__()
|
||||
self.r = r
|
||||
self.decoder_output_dim = decoder_output_dim
|
||||
self.postnet_output_dim = postnet_output_dim
|
||||
self.out_channels = out_channels
|
||||
self.bidirectional_decoder = bidirectional_decoder
|
||||
self.num_speakers = num_speakers
|
||||
self.speaker_embed_dim = 256
|
||||
|
@ -58,7 +58,7 @@ class Tacotron2(keras.models.Model):
|
|||
name="decoder",
|
||||
enable_tflite=enable_tflite,
|
||||
)
|
||||
self.postnet = Postnet(postnet_output_dim, 5, name="postnet")
|
||||
self.postnet = Postnet(out_channels, 5, name="postnet")
|
||||
|
||||
@tf.function(experimental_relax_shapes=True)
|
||||
def call(self, characters, text_lengths=None, frames=None, training=None):
|
||||
|
|
|
@ -44,8 +44,7 @@ def sequence_mask(sequence_length, max_len=None):
|
|||
batch_size = sequence_length.size(0)
|
||||
seq_range = np.empty([0, max_len], dtype=np.int8)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
if sequence_length.is_cuda:
|
||||
seq_range_expand = seq_range_expand.cuda()
|
||||
seq_range_expand = seq_range_expand.type_as(sequence_length)
|
||||
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
|
||||
# B x T_max
|
||||
return seq_range_expand < seq_length_expand
|
||||
|
@ -84,7 +83,7 @@ def setup_model(num_chars, num_speakers, c, enable_tflite=False):
|
|||
num_chars=num_chars,
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio["num_mels"],
|
||||
out_channels=c.audio["num_mels"],
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def _pad_data(x, length):
|
||||
|
@ -65,3 +66,12 @@ class StandardScaler:
|
|||
X *= self.scale_
|
||||
X += self.mean_
|
||||
return X
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
|
||||
# B x T_max
|
||||
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
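Usage sketch for the `sequence_mask` helper added above (assuming it is in scope): a batch of lengths becomes a boolean `[B, T_max]` mask.

import torch

lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths)  # max_len defaults to lengths.max() == 5
print(mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])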
|
|
|
@ -1,278 +0,0 @@
|
|||
import torch
|
||||
|
||||
from TTS.utils.generic_utils import find_module
|
||||
|
||||
|
||||
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
|
||||
def sequence_mask(sequence_length, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = sequence_length.data.max()
|
||||
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
|
||||
# B x T_max
|
||||
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
|
||||
|
||||
|
||||
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
|
||||
print(" > Using model: {}".format(c.model))
|
||||
MyModel = find_module("TTS.tts.models", c.model.lower())
|
||||
if c.model.lower() in "tacotron":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
use_gst=c.use_gst,
|
||||
gst=c.gst,
|
||||
memory_size=c.memory_size,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
)
|
||||
elif c.model.lower() == "tacotron2":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
num_speakers=num_speakers,
|
||||
r=c.r,
|
||||
postnet_output_dim=c.audio["num_mels"],
|
||||
decoder_output_dim=c.audio["num_mels"],
|
||||
use_gst=c.use_gst,
|
||||
gst=c.gst,
|
||||
attn_type=c.attention_type,
|
||||
attn_win=c.windowing,
|
||||
attn_norm=c.attention_norm,
|
||||
prenet_type=c.prenet_type,
|
||||
prenet_dropout=c.prenet_dropout,
|
||||
prenet_dropout_at_inference=c.prenet_dropout_at_inference,
|
||||
forward_attn=c.use_forward_attn,
|
||||
trans_agent=c.transition_agent,
|
||||
forward_attn_mask=c.forward_attn_mask,
|
||||
location_attn=c.location_attn,
|
||||
attn_K=c.attention_heads,
|
||||
separate_stopnet=c.separate_stopnet,
|
||||
bidirectional_decoder=c.bidirectional_decoder,
|
||||
double_decoder_consistency=c.double_decoder_consistency,
|
||||
ddc_r=c.ddc_r,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
)
|
||||
elif c.model.lower() == "glow_tts":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
hidden_channels_enc=c["hidden_channels_encoder"],
|
||||
hidden_channels_dec=c["hidden_channels_decoder"],
|
||||
hidden_channels_dp=c["hidden_channels_duration_predictor"],
|
||||
out_channels=c.audio["num_mels"],
|
||||
encoder_type=c.encoder_type,
|
||||
encoder_params=c.encoder_params,
|
||||
use_encoder_prenet=c["use_encoder_prenet"],
|
||||
inference_noise_scale=c.inference_noise_scale,
|
||||
num_flow_blocks_dec=12,
|
||||
kernel_size_dec=5,
|
||||
dilation_rate=1,
|
||||
num_block_layers=4,
|
||||
dropout_p_dec=0.05,
|
||||
num_speakers=num_speakers,
|
||||
c_in_channels=0,
|
||||
num_splits=4,
|
||||
num_squeeze=2,
|
||||
sigmoid_scale=False,
|
||||
mean_only=True,
|
||||
speaker_embedding_dim=speaker_embedding_dim,
|
||||
)
|
||||
elif c.model.lower() == "speedy_speech":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
out_channels=c.audio["num_mels"],
|
||||
hidden_channels=c["hidden_channels"],
|
||||
positional_encoding=c["positional_encoding"],
|
||||
encoder_type=c["encoder_type"],
|
||||
encoder_params=c["encoder_params"],
|
||||
decoder_type=c["decoder_type"],
|
||||
decoder_params=c["decoder_params"],
|
||||
c_in_channels=0,
|
||||
)
|
||||
elif c.model.lower() == "align_tts":
|
||||
model = MyModel(
|
||||
num_chars=num_chars + getattr(c, "add_blank", False),
|
||||
out_channels=c.audio["num_mels"],
|
||||
hidden_channels=c["hidden_channels"],
|
||||
hidden_channels_dp=c["hidden_channels_dp"],
|
||||
encoder_type=c["encoder_type"],
|
||||
encoder_params=c["encoder_params"],
|
||||
decoder_type=c["decoder_type"],
|
||||
decoder_params=c["decoder_params"],
|
||||
c_in_channels=0,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def is_tacotron(c):
|
||||
return "tacotron" in c["model"].lower()
|
||||

# def check_config_tts(c):
# check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
# check_argument('run_name', c, restricted=True, val_type=str)
# check_argument('run_description', c, val_type=str)

# # AUDIO
# # check_argument('audio', c, restricted=True, val_type=dict)

# # audio processing parameters
# # check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
# # check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
# # check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
# # check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
# # check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
# # check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
# # check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
# # check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
# # check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
# # check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)

# # vocabulary parameters
# check_argument('characters', c, restricted=False, val_type=dict)
# check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
# check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
# check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)

# # normalization parameters
# # check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
# # check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
# # check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
# # check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
# # check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
# # check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
# # check_argument('trim_db', c['audio'], restricted=True, val_type=int)

# # training parameters
# # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
# # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
# # check_argument('r', c, restricted=True, val_type=int, min_val=1)
# # check_argument('gradual_training', c, restricted=False, val_type=list)
# # check_argument('mixed_precision', c, restricted=False, val_type=bool)
# # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

# # loss parameters
# # check_argument('loss_masking', c, restricted=True, val_type=bool)
# # if c['model'].lower() in ['tacotron', 'tacotron2']:
# # check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# # check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
# if c['model'].lower() in ["speedy_speech", "align_tts"]:
# check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
# check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)

# # validation parameters
# # check_argument('run_eval', c, restricted=True, val_type=bool)
# # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
# # check_argument('test_sentences_file', c, restricted=False, val_type=str)

# # optimizer
# check_argument('noam_schedule', c, restricted=False, val_type=bool)
# check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
# check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
# check_argument('lr', c, restricted=True, val_type=float, min_val=0)
# check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
# check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
# check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)

# # tacotron prenet
# # check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
# # check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
# # check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)

# # attention
# check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
# check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
# check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
# check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
# check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)

# if c['model'].lower() in ['tacotron', 'tacotron2']:
# # stopnet
# # check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
# # check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)

# # Model Parameters for non-tacotron models
# if c['model'].lower() in ["speedy_speech", "align_tts"]:
# check_argument('positional_encoding', c, restricted=True, val_type=type)
# check_argument('encoder_type', c, restricted=True, val_type=str)
# check_argument('encoder_params', c, restricted=True, val_type=dict)
# check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)

# # GlowTTS parameters
# check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)

# # tensorboard
# # check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
# # check_argument('checkpoint', c, restricted=True, val_type=bool)
# # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)

# # dataloading
# # pylint: disable=import-outside-toplevel
# from TTS.tts.utils.text import cleaners
# # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
# # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
# # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
# # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
# # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
# # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
# # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
# # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)

# # paths
# # check_argument('output_path', c, restricted=True, val_type=str)

# # multi-speaker and gst
# # check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
# # check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
# # check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
# if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
# # check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
# # check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
# # check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
# # check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
# # check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
# # check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
# # check_argument('gst_num_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)

# # datasets - checking only the first entry
# # check_argument('datasets', c, restricted=True, val_type=list)
# # for dataset_entry in c['datasets']:
# # check_argument('name', dataset_entry, restricted=True, val_type=str)
# # check_argument('path', dataset_entry, restricted=True, val_type=str)
# # check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
# # check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
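# For reference, a minimal sketch of the validator the commented block above
# relies on; the real `check_argument` lives elsewhere in the repo, and this toy
# version only mirrors the keyword arguments used here (an assumption, not the
# actual implementation).
def check_argument_sketch(name, c, restricted=False, val_type=None, min_val=None, max_val=None, enum_list=None, alternative=None):
    if alternative is not None and alternative in c:
        return  # field satisfied by an alternative key
    if name not in c:
        assert not restricted, f" [!] {name} is required."
        return
    value = c[name]
    if val_type is not None:
        types = val_type if isinstance(val_type, list) else [val_type]
        assert any(isinstance(value, t) for t in types), f" [!] {name} has the wrong type."
    if min_val is not None:
        assert value >= min_val, f" [!] {name} must be >= {min_val}."
    if max_val is not None:
        assert value <= max_val, f" [!] {name} must be <= {max_val}."
    if enum_list is not None:
        assert str(value).lower() in enum_list, f" [!] {name} must be one of {enum_list}."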
@@ -1,120 +0,0 @@
import datetime
import os
import pickle as pickle_tts

import torch

from TTS.utils.io import RenamingUnpickler


def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    """Load ```TTS.tts.models``` checkpoints.

    Args:
        model (TTS.tts.models): model object to load the weights for.
        checkpoint_path (string): checkpoint file path.
        amp (apex.amp, optional): Apex amp object to load apex related state vars. Defaults to None.
        use_cuda (bool, optional): load model to GPU if True. Defaults to False.

    Returns:
        Tuple: the model with the loaded weights and the checkpoint state dict.
    """
    try:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if amp and "amp" in state:
        amp.load_state_dict(state["amp"])
    if use_cuda:
        model.cuda()
    # set model stepsize
    if hasattr(model.decoder, "r"):
        model.decoder.set_r(state["r"])
        print(" > Model r: ", state["r"])
    if eval:
        model.eval()
    return model, state


def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs):
    """Save ```TTS.tts.models``` states with extra fields.

    Args:
        model (TTS.tts.models.Model): models object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_path (str): output path to save the model file.
        characters (list): list of characters used in the model.
        amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
    """
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    state = {
        "model": model_state,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
        "characters": characters,
    }
    if amp_state_dict:
        state["amp"] = amp_state_dict
    state.update(kwargs)
    torch.save(state, output_path)


def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs):
    """Save model checkpoint, intended for saving checkpoints at training.

    Args:
        model (TTS.tts.models.Model): models object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): folder to save the checkpoint file into.
        characters (list): list of characters used in the model.
    """
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)


def save_best_model(
    target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs
):
    """Save model checkpoint, intended for saving the best model after each epoch.
    It compares the current model loss with the best loss so far and saves the
    model if the current loss is better.

    Args:
        target_loss (float): current model loss.
        best_loss (float): best loss so far.
        model (TTS.tts.models.Model): models object to be saved.
        optimizer (torch.optim.optimizers.Optimizer): model optimizer used for training.
        current_step (int): current number of training steps.
        epoch (int): current number of training epochs.
        r (int): model reduction rate for Tacotron models.
        output_folder (str): folder to save the best model file into.
        characters (list): list of characters used in the model.

    Returns:
        float: updated current best loss.
    """
    if target_loss < best_loss:
        file_name = "best_model.pth.tar"
        checkpoint_path = os.path.join(output_folder, file_name)
        print(" >> BEST MODEL : {}".format(checkpoint_path))
        save_model(
            model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs
        )
        best_loss = target_loss
    return best_loss
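# A minimal sketch of the save/restore round trip implemented by the functions
# above. The toy model and the paths are hypothetical; note that
# load_checkpoint() expects the model to expose a `decoder` attribute, as the
# TTS models do.
import os
import torch

class ToyDecoder:
    r = 2
    def set_r(self, r):
        self.r = r

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)
        self.decoder = ToyDecoder()

model = ToyModel()
optimizer = torch.optim.Adam(model.parameters())
out_dir = "/tmp/tts_run"  # hypothetical output folder
os.makedirs(out_dir, exist_ok=True)
save_checkpoint(model, optimizer, current_step=1000, epoch=3, r=2, output_folder=out_dir, characters=None)
model, state = load_checkpoint(model, os.path.join(out_dir, "checkpoint_1000.pth.tar"), eval=True)
print(state["step"], state["epoch"])  # -> 1000 3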
@@ -1,164 +1,94 @@
import json
import os
import random
from typing import Union
from typing import Any, Dict, List, Tuple, Union

import numpy as np
import torch
from coqpit import Coqpit

from TTS.config import load_config
from TTS.speaker_encoder.utils.generic_utils import setup_model
from TTS.utils.audio import AudioProcessor


def make_speakers_json_path(out_path):
    """Returns conventional speakers.json location."""
    return os.path.join(out_path, "speakers.json")


def load_speaker_mapping(out_path):
    """Loads speaker mapping if already present."""
    if os.path.splitext(out_path)[1] == ".json":
        json_file = out_path
    else:
        json_file = make_speakers_json_path(out_path)
    with open(json_file) as f:
        return json.load(f)


def save_speaker_mapping(out_path, speaker_mapping):
    """Saves speaker mapping if not yet present."""
    if out_path is not None:
        speakers_json_path = make_speakers_json_path(out_path)
        with open(speakers_json_path, "w") as f:
            json.dump(speaker_mapping, f, indent=4)


def get_speakers(items):
    """Returns a sorted, unique list of speakers in a given dataset."""
    speakers = {e[2] for e in items}
    return sorted(speakers)


def parse_speakers(c, args, meta_data_train, OUT_PATH):
    """Returns number of speakers, speaker embedding shape and speaker mapping"""
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            if c.use_external_speaker_embedding_file:  # if restore checkpoint and use External Embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                if not speaker_mapping:
                    print(
                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                    )
                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
                    if not speaker_mapping:
                        raise RuntimeError(
                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
                        )
                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
            elif (
                not c.use_external_speaker_embedding_file
            ):  # if restore checkpoint and don't use External Embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                speaker_embedding_dim = None
                assert all(
                    speaker in speaker_mapping for speaker in speakers
                ), "As of now, you cannot introduce new speakers to a previously trained model."
        elif (
            c.use_external_speaker_embedding_file and c.external_speaker_embedding_file
        ):  # if start new train using External Embedding file
            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
        elif (
            c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
        ):  # if start new train using External Embedding file and don't pass external embedding file
            raise ValueError(
                "use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder"
            )
        else:  # if start new train and don't use External Embedding file
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
            speaker_embedding_dim = None
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
    else:
        num_speakers = 0
        speaker_embedding_dim = None
        speaker_mapping = None

    return num_speakers, speaker_embedding_dim, speaker_mapping


class SpeakerManager:
    """It manages the multi-speaker setup for 🐸TTS models. It loads the speaker files and parses the information
    in a way that you can query. There are 3 different scenarios considered.
    """Manage the speakers for multi-speaker 🐸TTS models. Load a datafile and parse the information
    in a way that can be queried by speaker or clip.

    1. Models using speaker embedding layers. The metafile only includes a mapping of speaker names to ids.
    2. Models using external embedding vectors (x vectors). The metafile includes a dictionary in the following
       format.
    There are 3 different scenarios considered:

    ```
    {
        'clip_name.wav':{
            'name': 'speakerA',
            'embedding'[<x_vector_values>]
        },
        ...
    }
    ```
    1. Models using speaker embedding layers. The datafile only maps speaker names to ids used by the embedding layer.
    2. Models using d-vectors. The datafile includes a dictionary in the following format.

    3. Computing x vectors at inference with the speaker encoder. It loads the speaker encoder model and
       computes x vectors for a given instance.
    ::

    >>> # load audio processor and speaker encoder
    >>> ap = AudioProcessor(**config.audio)
    >>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
    >>> # load a sample audio and compute embedding
    >>> waveform = ap.load_wav(sample_wav_path)
    >>> mel = ap.melspectrogram(waveform)
    >>> x_vector = manager.compute_x_vector(mel.T)
        {
            'clip_name.wav':{
                'name': 'speakerA',
                'embedding'[<d_vector_values>]
            },
            ...
        }

    3. Computing the d-vectors by the speaker encoder. It loads the speaker encoder model and
       computes the d-vectors for a given clip or speaker.

    Args:
        x_vectors_file_path (str, optional): Path to the metafile including x vectors. Defaults to "".
        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by the
            TTS model. Defaults to "".
        d_vectors_file_path (str, optional): Path to the metafile including d-vectors. Defaults to "".
        speaker_id_file_path (str, optional): Path to the metafile that maps speaker names to ids used by
            TTS models. Defaults to "".
        encoder_model_path (str, optional): Path to the speaker encoder model file. Defaults to "".
        encoder_config_path (str, optional): Path to the speaker encoder config file. Defaults to "".

    Examples:
        >>> # load audio processor and speaker encoder
        >>> ap = AudioProcessor(**config.audio)
        >>> manager = SpeakerManager(encoder_model_path=encoder_model_path, encoder_config_path=encoder_config_path)
        >>> # load a sample audio and compute embedding
        >>> waveform = ap.load_wav(sample_wav_path)
        >>> mel = ap.melspectrogram(waveform)
        >>> d_vector = manager.compute_d_vector(mel.T)
    """

    def __init__(
        self,
        x_vectors_file_path: str = "",
        data_items: List[List[Any]] = None,
        d_vectors_file_path: str = "",
        speaker_id_file_path: str = "",
        encoder_model_path: str = "",
        encoder_config_path: str = "",
        use_cuda: bool = False,
    ):

        self.x_vectors = None
        self.speaker_ids = None
        self.clip_ids = None
        self.data_items = []
        self.d_vectors = {}
        self.speaker_ids = {}
        self.clip_ids = []
        self.speaker_encoder = None
        self.speaker_encoder_ap = None
        self.use_cuda = use_cuda

        if x_vectors_file_path:
            self.load_x_vectors_file(x_vectors_file_path)
        if data_items:
            self.speaker_ids, _ = self.parse_speakers_from_data(self.data_items)

        if d_vectors_file_path:
            self.set_d_vectors_from_file(d_vectors_file_path)

        if speaker_id_file_path:
            self.load_ids_file(speaker_id_file_path)
            self.set_speaker_ids_from_file(speaker_id_file_path)

        if encoder_model_path and encoder_config_path:
            self.init_speaker_encoder(encoder_model_path, encoder_config_path)

    @staticmethod
    def _load_json(json_file_path: str):
    def _load_json(json_file_path: str) -> Dict:
        with open(json_file_path) as f:
            return json.load(f)

    @staticmethod
    def _save_json(json_file_path: str, data: dict):
    def _save_json(json_file_path: str, data: dict) -> None:
        with open(json_file_path, "w") as f:
            json.dump(data, f, indent=4)

@@ -167,54 +97,131 @@ class SpeakerManager:
        return len(self.speaker_ids)

    @property
    def x_vector_dim(self):
        return len(self.x_vectors[list(self.x_vectors.keys())[0]]["embedding"])
    def speaker_names(self):
        return list(self.speaker_ids.keys())

    def parser_speakers_from_items(self, items: list):
        speaker_ids = sorted({item[2] for item in items})
        self.speaker_ids = speaker_ids
    @property
    def d_vector_dim(self):
        """Dimensionality of d_vectors. If d_vectors are not loaded, returns zero."""
        if self.d_vectors:
            return len(self.d_vectors[list(self.d_vectors.keys())[0]]["embedding"])
        return 0

    @staticmethod
    def parse_speakers_from_data(items: list) -> Tuple[Dict, int]:
        """Parse speaker IDs from data samples returned by `load_meta_data()`.

        Args:
            items (list): Data samples returned by `load_meta_data()`.

        Returns:
            Tuple[Dict, int]: speaker IDs and number of speakers.
        """
        speakers = sorted({item[2] for item in items})
        speaker_ids = {name: i for i, name in enumerate(speakers)}
        num_speakers = len(speaker_ids)
        return speaker_ids, num_speakers

    def save_ids_file(self, file_path: str):
        self._save_json(file_path, self.speaker_ids)
    def set_speaker_ids_from_data(self, items: List) -> None:
        """Set speaker IDs from data samples.

    def load_ids_file(self, file_path: str):
        Args:
            items (List): Data samples returned by `load_meta_data()`.
        """
        self.speaker_ids, _ = self.parse_speakers_from_data(items)

    def set_speaker_ids_from_file(self, file_path: str) -> None:
        """Set speaker IDs from a file.

        Args:
            file_path (str): Path to the file.
        """
        self.speaker_ids = self._load_json(file_path)

    def save_x_vectors_file(self, file_path: str):
        self._save_json(file_path, self.x_vectors)
    def save_speaker_ids_to_file(self, file_path: str) -> None:
        """Save speaker IDs to a json file.

    def load_x_vectors_file(self, file_path: str):
        self.x_vectors = self._load_json(file_path)
        self.speaker_ids = list(set(sorted(x["name"] for x in self.x_vectors.values())))
        self.clip_ids = list(set(sorted(clip_name for clip_name in self.x_vectors.keys())))
        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.speaker_ids)

    def get_x_vector_by_clip(self, clip_idx: str):
        return self.x_vectors[clip_idx]["embedding"]
    def save_d_vectors_to_file(self, file_path: str) -> None:
        """Save d_vectors to a json file.

    def get_x_vectors_by_speaker(self, speaker_idx: str):
        return [x["embedding"] for x in self.x_vectors.values() if x["name"] == speaker_idx]
        Args:
            file_path (str): Path to the output file.
        """
        self._save_json(file_path, self.d_vectors)

    def get_mean_x_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False):
        x_vectors = self.get_x_vectors_by_speaker(speaker_idx)
    def set_d_vectors_from_file(self, file_path: str) -> None:
        """Load d_vectors from a json file.

        Args:
            file_path (str): Path to the target json file.
        """
        self.d_vectors = self._load_json(file_path)
        speakers = sorted({x["name"] for x in self.d_vectors.values()})
        self.speaker_ids = {name: i for i, name in enumerate(speakers)}
        self.clip_ids = list(set(sorted(clip_name for clip_name in self.d_vectors.keys())))

    def get_d_vector_by_clip(self, clip_idx: str) -> List:
        """Get d_vector by clip ID.

        Args:
            clip_idx (str): Target clip ID.

        Returns:
            List: d_vector as a list.
        """
        return self.d_vectors[clip_idx]["embedding"]

    def get_d_vectors_by_speaker(self, speaker_idx: str) -> List[List]:
        """Get all d_vectors of a speaker.

        Args:
            speaker_idx (str): Target speaker ID.

        Returns:
            List[List]: all the d_vectors of the given speaker.
        """
        return [x["embedding"] for x in self.d_vectors.values() if x["name"] == speaker_idx]

    def get_mean_d_vector(self, speaker_idx: str, num_samples: int = None, randomize: bool = False) -> np.ndarray:
        """Get mean d_vector of a speaker ID.

        Args:
            speaker_idx (str): Target speaker ID.
            num_samples (int, optional): Number of samples to be averaged. Defaults to None.
            randomize (bool, optional): Pick random `num_samples` of d_vectors. Defaults to False.

        Returns:
            np.ndarray: Mean d_vector.
        """
        d_vectors = self.get_d_vectors_by_speaker(speaker_idx)
        if num_samples is None:
            x_vectors = np.stack(x_vectors).mean(0)
            d_vectors = np.stack(d_vectors).mean(0)
        else:
            assert len(x_vectors) >= num_samples, f" [!] speaker {speaker_idx} has number of samples < {num_samples}"
            assert len(d_vectors) >= num_samples, f" [!] speaker {speaker_idx} has number of samples < {num_samples}"
            if randomize:
                x_vectors = np.stack(random.choices(x_vectors, k=num_samples)).mean(0)
                d_vectors = np.stack(random.choices(d_vectors, k=num_samples)).mean(0)
            else:
                x_vectors = np.stack(x_vectors[:num_samples]).mean(0)
        return x_vectors
                d_vectors = np.stack(d_vectors[:num_samples]).mean(0)
        return d_vectors

    def get_speakers(self):
    def get_speakers(self) -> List:
        return self.speaker_ids

    def get_clips(self):
        return sorted(self.x_vectors.keys())
    def get_clips(self) -> List:
        return sorted(self.d_vectors.keys())

    def init_speaker_encoder(self, model_path: str, config_path: str) -> None:
        """Initialize a speaker encoder model.

        Args:
            model_path (str): Model file path.
            config_path (str): Model config file path.
        """
        self.speaker_encoder_config = load_config(config_path)
        self.speaker_encoder = setup_model(self.speaker_encoder_config)
        self.speaker_encoder.load_checkpoint(config_path, model_path, eval=True, use_cuda=self.use_cuda)

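# A short sketch of how the d-vector accessors above compose, using a
# hand-written mapping instead of a real `speakers.json`; the embedding values
# are random stand-ins.
import json
import tempfile

import numpy as np

mapping = {
    "clip_1.wav": {"name": "speakerA", "embedding": np.random.rand(256).tolist()},
    "clip_2.wav": {"name": "speakerA", "embedding": np.random.rand(256).tolist()},
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(mapping, f)

manager = SpeakerManager(d_vectors_file_path=f.name)
print(manager.num_speakers, manager.d_vector_dim)  # -> 1 256
mean_vec = manager.get_mean_d_vector("speakerA")   # average over both clips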
@@ -223,7 +230,16 @@ class SpeakerManager:
        # self.speaker_encoder_ap.do_sound_norm = True
        # self.speaker_encoder_ap.do_trim_silence = True

    def compute_x_vector_from_clip(self, wav_file: Union[str, list]) -> list:
    def compute_d_vector_from_clip(self, wav_file: Union[str, list]) -> list:
        """Compute a d_vector from a given audio file.

        Args:
            wav_file (Union[str, list]): Target file path.

        Returns:
            list: Computed d_vector.
        """

        def _compute(wav_file: str):
            waveform = self.speaker_encoder_ap.load_wav(wav_file, sr=self.speaker_encoder_ap.sample_rate)
            spec = self.speaker_encoder_ap.melspectrogram(waveform)

@@ -231,23 +247,31 @@ class SpeakerManager:
            if self.use_cuda:
                spec = spec.cuda()
            spec = spec.unsqueeze(0)
            x_vector = self.speaker_encoder.compute_embedding(spec)
            return x_vector
            d_vector = self.speaker_encoder.compute_embedding(spec)
            return d_vector

        if isinstance(wav_file, list):
            # compute the mean x_vector
            x_vectors = None
            # compute the mean d_vector
            d_vectors = None
            for wf in wav_file:
                x_vector = _compute(wf)
                if x_vectors is None:
                    x_vectors = x_vector
                d_vector = _compute(wf)
                if d_vectors is None:
                    d_vectors = d_vector
                else:
                    x_vectors += x_vector
            return (x_vectors / len(wav_file))[0].tolist()
        x_vector = _compute(wav_file)
        return x_vector[0].tolist()
                    d_vectors += d_vector
            return (d_vectors / len(wav_file))[0].tolist()
        d_vector = _compute(wav_file)
        return d_vector[0].tolist()

    def compute_x_vector(self, feats):
    def compute_d_vector(self, feats: Union[torch.Tensor, np.ndarray]) -> List:
        """Compute d_vector from features.

        Args:
            feats (Union[torch.Tensor, np.ndarray]): Input features.

        Returns:
            List: computed d_vector.
        """
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats)
        if feats.ndim == 2:

@@ -263,3 +287,90 @@ class SpeakerManager:
    def plot_embeddings(self):
        # TODO: implement speaker encoder
        raise NotImplementedError


def _set_file_path(path):
    """Find the speakers.json under the given path or the one above it.
    Intended to band-aid the different paths returned in restored and continued training."""
    path_restore = os.path.join(os.path.dirname(path), "speakers.json")
    path_continue = os.path.join(path, "speakers.json")
    if os.path.exists(path_restore):
        return path_restore
    if os.path.exists(path_continue):
        return path_continue
    raise FileNotFoundError(f" [!] `speakers.json` not found in {path}")


def load_speaker_mapping(out_path):
    """Loads speaker mapping if already present."""
    if os.path.splitext(out_path)[1] == ".json":
        json_file = out_path
    else:
        json_file = _set_file_path(out_path)
    with open(json_file) as f:
        return json.load(f)


def save_speaker_mapping(out_path, speaker_mapping):
    """Saves speaker mapping if not yet present."""
    if out_path is not None:
        speakers_json_path = _set_file_path(out_path)
        with open(speakers_json_path, "w") as f:
            json.dump(speaker_mapping, f, indent=4)


def get_speaker_manager(c: Coqpit, data: List = None, restore_path: str = None, out_path: str = None) -> SpeakerManager:
    """Initiate a `SpeakerManager` instance by the provided config.

    Args:
        c (Coqpit): Model configuration.
        restore_path (str): Path to a previous training folder.
        data (List): Data samples used in training to infer speakers from. It must be provided if speaker embedding
            layers are used. Defaults to None.
        out_path (str, optional): Save the generated speaker IDs to an output path. Defaults to None.

    Returns:
        SpeakerManager: initialized and ready to use instance.
    """
    speaker_manager = SpeakerManager()
    if c.use_speaker_embedding:
        if data is not None:
            speaker_manager.set_speaker_ids_from_data(data)
        if restore_path:
            speakers_file = _set_file_path(restore_path)
            # restoring speaker manager from a previous run.
            if c.use_d_vector_file:
                # restore speaker manager with the embedding file
                if not os.path.exists(speakers_file):
                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.d_vector_file")
                    if not os.path.exists(c.d_vector_file):
                        raise RuntimeError(
                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.d_vector_file"
                        )
                    speaker_manager.load_d_vectors_file(c.d_vector_file)
                speaker_manager.set_d_vectors_from_file(speakers_file)
            elif not c.use_d_vector_file:  # restore speaker manager with speaker ID file.
                speaker_ids_from_data = speaker_manager.speaker_ids
                speaker_manager.set_speaker_ids_from_file(speakers_file)
                assert all(
                    speaker in speaker_manager.speaker_ids for speaker in speaker_ids_from_data
                ), " [!] You cannot introduce new speakers to a pre-trained model."
        elif c.use_d_vector_file and c.d_vector_file:
            # new speaker manager with external speaker embeddings.
            speaker_manager.set_d_vectors_from_file(c.d_vector_file)
        elif c.use_d_vector_file and not c.d_vector_file:  # new speaker manager with speaker IDs file.
            raise ValueError(
                "use_d_vector_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder"
            )
        print(
            " > Training with {} speakers: {}".format(
                speaker_manager.num_speakers, ", ".join(speaker_manager.speaker_ids)
            )
        )
    # save file if path is defined
    if out_path:
        out_file_path = os.path.join(out_path, "speakers.json")
        print(f" > Saving `speakers.json` to {out_file_path}.")
        if c.use_d_vector_file and c.d_vector_file:
            speaker_manager.save_d_vectors_to_file(out_file_path)
        else:
            speaker_manager.save_speaker_ids_to_file(out_file_path)
    return speaker_manager
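# How a training script might call the factory above (a sketch, not a full
# script): `config`, `train_samples`, and `output_path` are placeholders for
# objects produced elsewhere by load_config() and load_meta_data().
speaker_manager = get_speaker_manager(config, data=train_samples, out_path=output_path)
if config.use_d_vector_file:
    embedding_dim = speaker_manager.d_vector_dim  # feeds the model's speaker projection
else:
    num_speakers = speaker_manager.num_speakers   # sizes the speaker embedding layer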
@@ -56,9 +56,6 @@ class SSIM(torch.nn.Module):
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window

@@ -69,10 +66,6 @@ class SSIM(torch.nn.Module):

def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = create_window(window_size, channel).type_as(img1)
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
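# The `type_as` handling above means callers never move the window to the right
# device themselves; a minimal check on random spectrogram-shaped tensors
# (assuming create_window() and _ssim() from this module are importable):
import torch

a = torch.rand(4, 1, 80, 120)        # B x C x n_mels x T, e.g. mel spectrograms
b = a + 0.05 * torch.randn_like(a)   # slightly perturbed copy
score = ssim(a, b)                   # scalar; close to 1 for similar inputs
print(float(score))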
@@ -1,8 +1,10 @@
import os
from typing import Dict

import numpy as np
import pkg_resources
import torch
from torch import nn

from .text import phoneme_to_sequence, text_to_sequence

@@ -13,7 +15,7 @@ if "tensorflow" in installed or "tensorflow-gpu" in installed:
    import tensorflow as tf


def text_to_seqvec(text, CONFIG):
def text_to_seq(text, CONFIG):
    text_cleaner = [CONFIG.text_cleaner]
    # text or phonemes to sequence vector
    if CONFIG.use_phonemes:

@@ -65,61 +67,45 @@ def compute_style_mel(style_wav, ap, cuda=False):
    return style_mel


def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
    if "tacotron" in CONFIG.model.lower():
        if CONFIG.gst:
            decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
            )
        else:
            if truncated:
                decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model.inference(
                    inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
                )
    elif "glow" in CONFIG.model.lower():
        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
        if hasattr(model, "module"):
            # distributed model
            postnet_output, _, _, _, alignments, _, _ = model.module.inference(
                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
            )
        else:
            postnet_output, _, _, _, alignments, _, _ = model.inference(
                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
            )
        postnet_output = postnet_output.permute(0, 2, 1)
        # these only belong to tacotron models.
        decoder_output = None
        stop_tokens = None
    elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
        inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)  # pylint: disable=not-callable
        if hasattr(model, "module"):
            # distributed model
            postnet_output, alignments = model.module.inference(
                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
            )
        else:
            postnet_output, alignments = model.inference(
                inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
            )
        postnet_output = postnet_output.permute(0, 2, 1)
        # these only belong to tacotron models.
        decoder_output = None
        stop_tokens = None
def run_model_torch(
    model: nn.Module,
    inputs: torch.Tensor,
    speaker_id: int = None,
    style_mel: torch.Tensor = None,
    d_vector: torch.Tensor = None,
) -> Dict:
    """Run a torch model for inference. It does not support batch inference.

    Args:
        model (nn.Module): The model to run inference.
        inputs (torch.Tensor): Input tensor with character ids.
        speaker_id (int, optional): Input speaker ids for multi-speaker models. Defaults to None.
        style_mel (torch.Tensor, optional): Spectrograms used for voice styling. Defaults to None.
        d_vector (torch.Tensor, optional): d-vector for multi-speaker models. Defaults to None.

    Returns:
        Dict: model outputs.
    """
    input_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device)
    if hasattr(model, "module"):
        _func = model.module.inference
    else:
        raise ValueError("[!] Unknown model name.")
    return decoder_output, postnet_output, alignments, stop_tokens
        _func = model.inference
    outputs = _func(
        inputs,
        aux_input={
            "x_lengths": input_lengths,
            "speaker_ids": speaker_id,
            "d_vectors": d_vector,
            "style_mel": style_mel,
        },
    )
    return outputs


def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
def run_model_tf(model, inputs, CONFIG, speaker_id=None, style_mel=None):
    if CONFIG.gst and style_mel is not None:
        raise NotImplementedError(" [!] GST inference not implemented for TF")
    if truncated:
        raise NotImplementedError(" [!] Truncated inference not implemented for TF")
    if speaker_id is not None:
        raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
    # TODO: handle multispeaker case

@@ -127,11 +113,9 @@ def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=No
    return decoder_output, postnet_output, alignments, stop_tokens


def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
def run_model_tflite(model, inputs, CONFIG, speaker_id=None, style_mel=None):
    if CONFIG.gst and style_mel is not None:
        raise NotImplementedError(" [!] GST inference not implemented for TfLite")
    if truncated:
        raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
    if speaker_id is not None:
        raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
    # get input and output details

@@ -152,14 +136,6 @@ def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_me
    return decoder_output, postnet_output, None, None


def parse_outputs_torch(postnet_output, decoder_output, alignments, stop_tokens):
    postnet_output = postnet_output[0].data.cpu().numpy()
    decoder_output = None if decoder_output is None else decoder_output[0].data.cpu().numpy()
    alignment = alignments[0].cpu().data.numpy()
    stop_tokens = None if stop_tokens is None else stop_tokens[0].cpu().numpy()
    return postnet_output, decoder_output, alignment, stop_tokens


def parse_outputs_tf(postnet_output, decoder_output, alignments, stop_tokens):
    postnet_output = postnet_output[0].numpy()
    decoder_output = decoder_output[0].numpy()

@@ -186,23 +162,22 @@ def inv_spectrogram(postnet_output, ap, CONFIG):
    return wav


def id_to_torch(speaker_id, cuda=False):
def speaker_id_to_torch(speaker_id, cuda=False):
    if speaker_id is not None:
        speaker_id = np.asarray(speaker_id)
        # TODO: test this for tacotron models
        speaker_id = torch.from_numpy(speaker_id)
    if cuda:
        return speaker_id.cuda()
    return speaker_id


def embedding_to_torch(speaker_embedding, cuda=False):
    if speaker_embedding is not None:
        speaker_embedding = np.asarray(speaker_embedding)
        speaker_embedding = torch.from_numpy(speaker_embedding).unsqueeze(0).type(torch.FloatTensor)
def embedding_to_torch(d_vector, cuda=False):
    if d_vector is not None:
        d_vector = np.asarray(d_vector)
        d_vector = torch.from_numpy(d_vector).unsqueeze(0).type(torch.FloatTensor)
    if cuda:
        return speaker_embedding.cuda()
    return speaker_embedding
        return d_vector.cuda()
    return d_vector


# TODO: perform GL with pytorch for batching

@@ -231,11 +206,10 @@ def synthesis(
    ap,
    speaker_id=None,
    style_wav=None,
    truncated=False,
    enable_eos_bos_chars=False,  # pylint: disable=unused-argument
    use_griffin_lim=False,
    do_trim_silence=False,
    speaker_embedding=None,
    d_vector=None,
    backend="torch",
):
    """Synthesize voice for the given text.

@@ -249,8 +223,6 @@ def synthesis(
            model outputs.
        speaker_id (int): id of speaker
        style_wav (str | Dict[str, float]): Used for style embedding of GST.
        truncated (bool): keep model states after inference. It can be used
            for continuous inference at long texts.
        enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
        do_trim_silence (bool): trim silence after synthesis.
        backend (str): tf or torch

@@ -263,54 +235,54 @@ def synthesis(
    else:
        style_mel = compute_style_mel(style_wav, ap, cuda=use_cuda)
    # preprocess the given text
    inputs = text_to_seqvec(text, CONFIG)
    text_inputs = text_to_seq(text, CONFIG)
    # pass tensors to backend
    if backend == "torch":
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
            speaker_id = speaker_id_to_torch(speaker_id, cuda=use_cuda)

        if speaker_embedding is not None:
            speaker_embedding = embedding_to_torch(speaker_embedding, cuda=use_cuda)
        if d_vector is not None:
            d_vector = embedding_to_torch(d_vector, cuda=use_cuda)

        if not isinstance(style_mel, dict):
            style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
        inputs = inputs.unsqueeze(0)
    elif backend == "tf":
        text_inputs = numpy_to_torch(text_inputs, torch.long, cuda=use_cuda)
        text_inputs = text_inputs.unsqueeze(0)
    elif backend in ["tf", "tflite"]:
        # TODO: handle speaker id for tf model
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
    elif backend == "tflite":
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
        text_inputs = numpy_to_tf(text_inputs, tf.int32)
        text_inputs = tf.expand_dims(text_inputs, 0)
    # synthesize voice
    if backend == "torch":
        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
            model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding
        )
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
            postnet_output, decoder_output, alignments, stop_tokens
        )
        outputs = run_model_torch(model, text_inputs, speaker_id, style_mel, d_vector=d_vector)
        model_outputs = outputs["model_outputs"]
        model_outputs = model_outputs[0].data.cpu().numpy()
        alignments = outputs["alignments"]
    elif backend == "tf":
        decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
            model, inputs, CONFIG, truncated, speaker_id, style_mel
            model, text_inputs, CONFIG, speaker_id, style_mel
        )
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
        model_outputs, decoder_output, alignments, stop_tokens = parse_outputs_tf(
            postnet_output, decoder_output, alignments, stop_tokens
        )
    elif backend == "tflite":
        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
            model, inputs, CONFIG, truncated, speaker_id, style_mel
        decoder_output, postnet_output, alignments, stop_tokens = run_model_tflite(
            model, text_inputs, CONFIG, speaker_id, style_mel
        )
        postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
        model_outputs, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
    # convert outputs to numpy
    # plot results
    wav = None
    if use_griffin_lim:
        wav = inv_spectrogram(postnet_output, ap, CONFIG)
        wav = inv_spectrogram(model_outputs, ap, CONFIG)
        # trim silence
        if do_trim_silence:
            wav = trim_silence(wav, ap)
    return wav, alignment, decoder_output, postnet_output, stop_tokens, inputs
    return_dict = {
        "wav": wav,
        "alignments": alignments,
        "model_outputs": model_outputs,
        "text_inputs": text_inputs,
    }
    return return_dict
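# A sketch of consuming the dict-style return value introduced above; `model`,
# `config`, and `ap` are placeholders for objects loaded elsewhere (e.g. via
# load_config() and setup_model()).
outputs = synthesis(model, "Hello world.", config, use_cuda=False, ap=ap, use_griffin_lim=True)
wav = outputs["wav"]                  # waveform; None unless use_griffin_lim=True
alignments = outputs["alignments"]    # attention map, useful for plotting
ap.save_wav(wav, "tts_output.wav")    # AudioProcessor handles the file I/O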
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
# adapted from https://github.com/keithito/tacotron

import re
import unicodedata

@@ -65,7 +65,7 @@ def basic_cleaners(text):

def transliteration_cleaners(text):
    """Pipeline for non-English text that transliterates to ASCII."""
    text = convert_to_ascii(text)
    # text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text

@@ -89,7 +89,7 @@ def basic_turkish_cleaners(text):

def english_cleaners(text):
    """Pipeline for English text, including number and abbreviation expansion."""
    text = convert_to_ascii(text)
    # text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_time_english(text)
    text = expand_numbers(text)

@@ -129,7 +129,7 @@ def chinese_mandarin_cleaners(text: str) -> str:
def phoneme_cleaners(text):
    """Pipeline for phonemes mode, including number and abbreviation expansion."""
    text = expand_numbers(text)
    text = convert_to_ascii(text)
    # text = convert_to_ascii(text)
    text = expand_abbreviations(text)
    text = replace_symbols(text)
    text = remove_aux_symbols(text)
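# The cleaners above are ordered compositions of small text transforms; a
# self-contained toy version of the same pattern (the real lowercase() and
# collapse_whitespace() live elsewhere in this module):
import re

def toy_lowercase(text):
    return text.lower()

def toy_collapse_whitespace(text):
    return re.sub(r"\s+", " ", text).strip()

def toy_basic_cleaners(text):
    # same shape as basic_cleaners(): apply the transforms in a fixed order
    for fn in (toy_lowercase, toy_collapse_whitespace):
        text = fn(text)
    return text

print(toy_basic_cleaners("  Hello   WORLD  "))  # -> "hello world"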
@ -1,183 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Argument parser for training scripts."""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from TTS.config import load_config
|
||||
from TTS.tts.utils.text.symbols import parse_symbols
|
||||
from TTS.utils.console_logger import ConsoleLogger
|
||||
from TTS.utils.generic_utils import create_experiment_folder, get_git_branch
|
||||
from TTS.utils.io import copy_model_files
|
||||
from TTS.utils.tensorboard_logger import TensorboardLogger
|
||||
|
||||
|
||||
def init_arguments(argv):
|
||||
"""Parse command line arguments of training scripts.
|
||||
|
||||
Args:
|
||||
argv (list): This is a list of input arguments as given by sys.argv
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--continue_path",
|
||||
type=str,
|
||||
help=(
|
||||
"Training output folder to continue training. Used to continue "
|
||||
"a training. If it is used, 'config_path' is ignored."
|
||||
),
|
||||
default="",
|
||||
required="--config_path" not in argv,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--best_path",
|
||||
type=str,
|
||||
help=(
|
||||
"Best model file to be used for extracting best loss."
|
||||
"If not specified, the latest best model in continue path is used"
|
||||
),
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
|
||||
)
|
||||
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
|
||||
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
|
||||
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_last_checkpoint(path):
|
||||
"""Get latest checkpoint or/and best model in path.
|
||||
|
||||
It is based on globbing for `*.pth.tar` and the RegEx
|
||||
`(checkpoint|best_model)_([0-9]+)`.
|
||||
|
||||
Args:
|
||||
path (list): Path to files to be compared.
|
||||
|
||||
Raises:
|
||||
ValueError: If no checkpoint or best_model files are found.
|
||||
|
||||
Returns:
|
||||
last_checkpoint (str): Last checkpoint filename.
|
||||
"""
|
||||
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
|
||||
last_models = {}
|
||||
last_model_nums = {}
|
||||
for key in ["checkpoint", "best_model"]:
|
||||
last_model_num = None
|
||||
last_model = None
|
||||
# pass all the checkpoint files and find
|
||||
# the one with the largest model number suffix.
|
||||
for file_name in file_names:
|
||||
match = re.search(f"{key}_([0-9]+)", file_name)
|
||||
if match is not None:
|
||||
model_num = int(match.groups()[0])
|
||||
if last_model_num is None or model_num > last_model_num:
|
||||
last_model_num = model_num
|
||||
last_model = file_name
|
||||
|
||||
# if there is not checkpoint found above
|
||||
# find the checkpoint with the latest
|
||||
# modification date.
|
||||
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = torch.load(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model
            last_model_nums[key] = last_model_num

    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if "checkpoint" not in last_models:  # no checkpoint just best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models["best_model"] = None
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models["checkpoint"], last_models["best_model"]


def process_args(args):
    """Process parsed command line arguments.

    Args:
        args (argparse.Namespace or dict like): Parsed input arguments.

    Returns:
        c (TTS.utils.io.AttrDict): Config parameters.
        out_path (str): Path to save models and logging.
        audio_path (str): Path to save generated test audios.
        c_logger (TTS.utils.console_logger.ConsoleLogger): Class that does
            logging to the console.
        tb_logger (TTS.utils.tensorboard.TensorboardLogger): Class that does
            the TensorBoard logging.
    """
    if isinstance(args, tuple):
        args, coqpit_overrides = args
    if args.continue_path:
        # continue a previous training from its output folder
        experiment_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, "config.json")
        args.restore_path, best_model = get_last_checkpoint(args.continue_path)
        if not args.best_path:
            args.best_path = best_model
    # setup output paths and read configs
    config = load_config(args.config_path)
    # override values from command-line args
    config.parse_known_args(coqpit_overrides, relaxed_parser=True)
    if config.mixed_precision:
        print(" > Mixed precision mode is ON")
    experiment_path = args.continue_path
    if not experiment_path:
        experiment_path = create_experiment_folder(config.output_path, config.run_name, args.debug)
    audio_path = os.path.join(experiment_path, "test_audios")
    # setup rank 0 process in distributed training
    tb_logger = None
    if args.rank == 0:
        os.makedirs(audio_path, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        # if model characters are not set in the config file
        # save the default set to the config file for future
        # compatibility.
        if config.has("characters_config"):
            used_characters = parse_symbols()
            new_fields["characters"] = used_characters
        copy_model_files(config, experiment_path, new_fields)
        os.chmod(audio_path, 0o775)
        os.chmod(experiment_path, 0o775)
        tb_logger = TensorboardLogger(experiment_path, model_name=config.model)
        # write model desc to tensorboard
        tb_logger.tb_add_text("model-config", f"<pre>{config.to_json()}</pre>", 0)
    c_logger = ConsoleLogger()
    return config, experiment_path, audio_path, c_logger, tb_logger


def init_training(argv):
    """Initialization of a training run."""
    parser = init_arguments(argv)
    args = parser.parse_known_args()
    config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args)
    return args[0], config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger
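For context, a minimal sketch of how a training script might consume `init_training` (not part of the commit; the `TTS.utils.arguments` module path is assumed, since this diff does not show the file name):

    import sys

    from TTS.utils.arguments import init_training

    if __name__ == "__main__":
        # parse CLI args, restore or create the experiment folder, build the loggers
        args, config, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = init_training(sys.argv)
        print(f" > Experiment outputs: {OUT_PATH}")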
@ -1,14 +1,93 @@
from typing import Dict, Tuple

import librosa
import numpy as np
import scipy.io.wavfile
import scipy.signal
import soundfile as sf
import torch
from torch import nn

from TTS.tts.utils.data import StandardScaler

# import pyworld as pw


class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """TODO: Merge this with audio.py"""

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: :math:`[B, T]` or :math:`[B, 1, T]`
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()
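A quick usage sketch for `TorchSTFT` on a dummy batch (not part of the commit; parameter values are assumed, shapes noted in comments):

    import torch

    stft = TorchSTFT(n_fft=1024, hop_length=256, win_length=1024,
                     sample_rate=22050, n_mels=80, use_mel=True)
    wav = torch.randn(2, 22050)  # [B, T], two 1-second clips at 22.05 kHz
    mel = stft(wav)              # [B, n_mels, T // hop_length + 1] -> [2, 80, 87]
    print(mel.shape)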


# pylint: disable=too-many-public-methods
class AudioProcessor(object):
    """Audio Processor for TTS used by all the data pipelines.

@ -140,7 +219,12 @@ class AudioProcessor(object):
    ### setting up the parameters ###
    def _build_mel_basis(
        self,
    ):
    ) -> np.ndarray:
        """Build melspectrogram basis.

        Returns:
            np.ndarray: melspectrogram basis.
        """
        if self.mel_fmax is not None:
            assert self.mel_fmax <= self.sample_rate // 2
        return librosa.filters.mel(

@ -149,8 +233,12 @@ class AudioProcessor(object):

    def _stft_parameters(
        self,
    ):
        """Compute necessary stft parameters with given time values"""
    ) -> Tuple[int, int]:
        """Compute the real STFT parameters from the time values.

        Returns:
            Tuple[int, int]: hop length and window length for STFT.
        """
        factor = self.frame_length_ms / self.frame_shift_ms
        assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
        hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate)

@ -158,8 +246,18 @@ class AudioProcessor(object):
        return hop_length, win_length

    ### normalization ###
    def normalize(self, S):
        """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
    def normalize(self, S: np.ndarray) -> np.ndarray:
        """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]`

        Args:
            S (np.ndarray): Spectrogram to normalize.

        Raises:
            RuntimeError: Mean and variance is computed from incompatible parameters.

        Returns:
            np.ndarray: Normalized spectrogram.
        """
        # pylint: disable=no-else-return
        S = S.copy()
        if self.signal_norm:

@ -189,8 +287,18 @@ class AudioProcessor(object):
        else:
            return S

    def denormalize(self, S):
        """denormalize values"""
    def denormalize(self, S: np.ndarray) -> np.ndarray:
        """Denormalize spectrogram values.

        Args:
            S (np.ndarray): Spectrogram to denormalize.

        Raises:
            RuntimeError: Mean and variance are incompatible.

        Returns:
            np.ndarray: Denormalized spectrogram.
        """
        # pylint: disable=no-else-return
        S_denorm = S.copy()
        if self.signal_norm:

@ -218,7 +326,16 @@ class AudioProcessor(object):
        return S_denorm

    ### Mean-STD scaling ###
    def load_stats(self, stats_path):
    def load_stats(self, stats_path: str) -> Tuple[np.array, np.array, np.array, np.array, Dict]:
        """Loading mean and variance statistics from a `npy` file.

        Args:
            stats_path (str): Path to the `npy` file containing the computed statistics.

        Returns:
            Tuple[np.array, np.array, np.array, np.array, Dict]: loaded statistics and the config used to
                compute them.
        """
        stats = np.load(stats_path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg
        mel_mean = stats["mel_mean"]
        mel_std = stats["mel_std"]

@ -237,7 +354,17 @@ class AudioProcessor(object):
        return mel_mean, mel_std, linear_mean, linear_std, stats_config

    # pylint: disable=attribute-defined-outside-init
    def setup_scaler(self, mel_mean, mel_std, linear_mean, linear_std):
    def setup_scaler(
        self, mel_mean: np.ndarray, mel_std: np.ndarray, linear_mean: np.ndarray, linear_std: np.ndarray
    ) -> None:
        """Initialize scaler objects used in mean-std normalization.

        Args:
            mel_mean (np.ndarray): Mean for melspectrograms.
            mel_std (np.ndarray): STD for melspectrograms.
            linear_mean (np.ndarray): Mean for full scale spectrograms.
            linear_std (np.ndarray): STD for full scale spectrograms.
        """
        self.mel_scaler = StandardScaler()
        self.mel_scaler.set_stats(mel_mean, mel_std)
        self.linear_scaler = StandardScaler()

@ -245,32 +372,78 @@ class AudioProcessor(object):

    ### DB and AMP conversion ###
    # pylint: disable=no-self-use
    def _amp_to_db(self, x):
    def _amp_to_db(self, x: np.ndarray) -> np.ndarray:
        """Convert amplitude values to decibels.

        Args:
            x (np.ndarray): Amplitude spectrogram.

        Returns:
            np.ndarray: Decibels spectrogram.
        """
        return self.spec_gain * _log(np.maximum(1e-5, x), self.base)

    # pylint: disable=no-self-use
    def _db_to_amp(self, x):
    def _db_to_amp(self, x: np.ndarray) -> np.ndarray:
        """Convert decibels spectrogram to amplitude spectrogram.

        Args:
            x (np.ndarray): Decibels spectrogram.

        Returns:
            np.ndarray: Amplitude spectrogram.
        """
        return _exp(x / self.spec_gain, self.base)

    ### Preemphasis ###
    def apply_preemphasis(self, x):
    def apply_preemphasis(self, x: np.ndarray) -> np.ndarray:
        """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values.

        Args:
            x (np.ndarray): Audio signal.

        Raises:
            RuntimeError: Preemphasis coeff is set to 0.

        Returns:
            np.ndarray: Decorrelated audio signal.
        """
        if self.preemphasis == 0:
            raise RuntimeError(" [!] Preemphasis is set 0.0.")
        return scipy.signal.lfilter([1, -self.preemphasis], [1], x)

    def apply_inv_preemphasis(self, x):
    def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray:
        """Reverse pre-emphasis."""
        if self.preemphasis == 0:
            raise RuntimeError(" [!] Preemphasis is set 0.0.")
        return scipy.signal.lfilter([1], [1, -self.preemphasis], x)

    ### SPECTROGRAMs ###
    def _linear_to_mel(self, spectrogram):
    def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray:
        """Project a full scale spectrogram to a melspectrogram.

        Args:
            spectrogram (np.ndarray): Full scale spectrogram.

        Returns:
            np.ndarray: Melspectrogram
        """
        return np.dot(self.mel_basis, spectrogram)

    def _mel_to_linear(self, mel_spec):
    def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray:
        """Convert a melspectrogram to full scale spectrogram."""
        return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec))

    def spectrogram(self, y):
    def spectrogram(self, y: np.ndarray) -> np.ndarray:
        """Compute a spectrogram from a waveform.

        Args:
            y (np.ndarray): Waveform.

        Returns:
            np.ndarray: Spectrogram.
        """
        if self.preemphasis != 0:
            D = self._stft(self.apply_preemphasis(y))
        else:

@ -278,7 +451,8 @@ class AudioProcessor(object):
        S = self._amp_to_db(np.abs(D))
        return self.normalize(S).astype(np.float32)

    def melspectrogram(self, y):
    def melspectrogram(self, y: np.ndarray) -> np.ndarray:
        """Compute a melspectrogram from a waveform."""
        if self.preemphasis != 0:
            D = self._stft(self.apply_preemphasis(y))
        else:

@ -286,8 +460,8 @@ class AudioProcessor(object):
        S = self._amp_to_db(self._linear_to_mel(np.abs(D)))
        return self.normalize(S).astype(np.float32)

    def inv_spectrogram(self, spectrogram):
        """Converts spectrogram to waveform using librosa"""
    def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray:
        """Convert a spectrogram to a waveform using the Griffin-Lim vocoder."""
        S = self.denormalize(spectrogram)
        S = self._db_to_amp(S)
        # Reconstruct phase

@ -295,8 +469,8 @@ class AudioProcessor(object):
        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def inv_melspectrogram(self, mel_spectrogram):
        """Converts melspectrogram to waveform using librosa"""
    def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray:
        """Convert a melspectrogram to a waveform using the Griffin-Lim vocoder."""
        D = self.denormalize(mel_spectrogram)
        S = self._db_to_amp(D)
        S = self._mel_to_linear(S)  # Convert back to linear

@ -304,7 +478,15 @@ class AudioProcessor(object):
        return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
        return self._griffin_lim(S ** self.power)

    def out_linear_to_mel(self, linear_spec):
    def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray:
        """Convert a full scale linear spectrogram output of a network to a melspectrogram.

        Args:
            linear_spec (np.ndarray): Normalized full scale linear spectrogram.

        Returns:
            np.ndarray: Normalized melspectrogram.
        """
        S = self.denormalize(linear_spec)
        S = self._db_to_amp(S)
        S = self._linear_to_mel(np.abs(S))

@ -313,7 +495,15 @@ class AudioProcessor(object):
        return mel

    ### STFT and ISTFT ###
    def _stft(self, y):
    def _stft(self, y: np.ndarray) -> np.ndarray:
        """Librosa STFT wrapper.

        Args:
            y (np.ndarray): Audio signal.

        Returns:
            np.ndarray: Complex number array.
        """
        return librosa.stft(
            y=y,
            n_fft=self.fft_size,

@ -324,7 +514,8 @@ class AudioProcessor(object):
            center=True,
        )

    def _istft(self, y):
    def _istft(self, y: np.ndarray) -> np.ndarray:
        """Librosa iSTFT wrapper."""
        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)

    def _griffin_lim(self, S):

@ -337,7 +528,8 @@ class AudioProcessor(object):
        return y

    def compute_stft_paddings(self, x, pad_sides=1):
        """compute right padding (final frame) or both sides padding (first and final frames)"""
        """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding
        (first and final frames)."""
        assert pad_sides in (1, 2)
        pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
        if pad_sides == 1:

@ -357,7 +549,17 @@ class AudioProcessor(object):
        # return f0

    ### Audio Processing ###
    def find_endpoint(self, wav, threshold_db=-40, min_silence_sec=0.8):
    def find_endpoint(self, wav: np.ndarray, threshold_db=-40, min_silence_sec=0.8) -> int:
        """Find the last point without silence at the end of an audio signal.

        Args:
            wav (np.ndarray): Audio signal.
            threshold_db (int, optional): Silence threshold in decibels. Defaults to -40.
            min_silence_sec (float, optional): Ignore silences that are shorter than this in secs. Defaults to 0.8.

        Returns:
            int: Last point without silence.
        """
        window_length = int(self.sample_rate * min_silence_sec)
        hop_length = int(window_length / 4)
        threshold = self._db_to_amp(threshold_db)

@ -375,11 +577,28 @@ class AudioProcessor(object):
        ]

    @staticmethod
    def sound_norm(x):
    def sound_norm(x: np.ndarray) -> np.ndarray:
        """Normalize the volume of an audio signal.

        Args:
            x (np.ndarray): Raw waveform.

        Returns:
            np.ndarray: Volume normalized waveform.
        """
        return x / abs(x).max() * 0.95

    ### save and load ###
    def load_wav(self, filename, sr=None):
    def load_wav(self, filename: str, sr: int = None) -> np.ndarray:
        """Read a wav file using Librosa and optionally resample, silence trim, volume normalize.

        Args:
            filename (str): Path to the wav file.
            sr (int, optional): Sampling rate for resampling. Defaults to None.

        Returns:
            np.ndarray: Loaded waveform.
        """
        if self.resample:
            x, sr = librosa.load(filename, sr=self.sample_rate)
        elif sr is None:

@ -396,12 +615,19 @@ class AudioProcessor(object):
            x = self.sound_norm(x)
        return x

    def save_wav(self, wav, path, sr=None):
    def save_wav(self, wav: np.ndarray, path: str, sr: int = None) -> None:
        """Save a waveform to a file using Scipy.

        Args:
            wav (np.ndarray): Waveform to save.
            path (str): Path to an output file.
            sr (int, optional): Sampling rate used for saving to the file. Defaults to None.
        """
        wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
        scipy.io.wavfile.write(path, sr if sr else self.sample_rate, wav_norm.astype(np.int16))

    @staticmethod
    def mulaw_encode(wav, qc):
    def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray:
        mu = 2 ** qc - 1
        # wav_abs = np.minimum(np.abs(wav), 1.0)
        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)

@ -423,11 +649,21 @@ class AudioProcessor(object):
        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)

    @staticmethod
    def quantize(x, bits):
    def quantize(x: np.ndarray, bits: int) -> np.ndarray:
        """Quantize a waveform to a given number of bits.

        Args:
            x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`.
            bits (int): Number of quantization bits.

        Returns:
            np.ndarray: Quantized waveform.
        """
        return (x + 1.0) * (2 ** bits - 1) / 2

    @staticmethod
    def dequantize(x, bits):
        """Dequantize a waveform from the given number of bits."""
        return 2 * x / (2 ** bits - 1) - 1
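A quick numeric check of the `quantize`/`dequantize` pair above (not part of the commit; input values are assumed to be normalized into `[-1, 1]`):

    import numpy as np

    x = np.array([-1.0, 0.0, 0.5, 1.0])
    q = AudioProcessor.quantize(x, bits=10)       # floats in [0, 1023]
    x_hat = AudioProcessor.dequantize(q, bits=10)
    print(np.allclose(x, x_hat))                  # True, the mapping is exactly invertible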


@ -0,0 +1,75 @@
class TrainerCallback:
    def __init__(self, trainer):
        super().__init__()
        self.trainer = trainer

    def on_init_start(self) -> None:
        if hasattr(self.trainer.model, "on_init_start"):
            self.trainer.model.on_init_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_init_start"):
            self.trainer.criterion.on_init_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_init_start"):
            self.trainer.optimizer.on_init_start(self.trainer)

    def on_init_end(self) -> None:
        if hasattr(self.trainer.model, "on_init_end"):
            self.trainer.model.on_init_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_init_end"):
            self.trainer.criterion.on_init_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_init_end"):
            self.trainer.optimizer.on_init_end(self.trainer)

    def on_epoch_start(self) -> None:
        if hasattr(self.trainer.model, "on_epoch_start"):
            self.trainer.model.on_epoch_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_epoch_start"):
            self.trainer.criterion.on_epoch_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_epoch_start"):
            self.trainer.optimizer.on_epoch_start(self.trainer)

    def on_epoch_end(self) -> None:
        if hasattr(self.trainer.model, "on_epoch_end"):
            self.trainer.model.on_epoch_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_epoch_end"):
            self.trainer.criterion.on_epoch_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_epoch_end"):
            self.trainer.optimizer.on_epoch_end(self.trainer)

    def on_train_step_start(self) -> None:
        if hasattr(self.trainer.model, "on_train_step_start"):
            self.trainer.model.on_train_step_start(self.trainer)

        if hasattr(self.trainer.criterion, "on_train_step_start"):
            self.trainer.criterion.on_train_step_start(self.trainer)

        if hasattr(self.trainer.optimizer, "on_train_step_start"):
            self.trainer.optimizer.on_train_step_start(self.trainer)

    def on_train_step_end(self) -> None:
        if hasattr(self.trainer.model, "on_train_step_end"):
            self.trainer.model.on_train_step_end(self.trainer)

        if hasattr(self.trainer.criterion, "on_train_step_end"):
            self.trainer.criterion.on_train_step_end(self.trainer)

        if hasattr(self.trainer.optimizer, "on_train_step_end"):
            self.trainer.optimizer.on_train_step_end(self.trainer)

    def on_keyboard_interrupt(self) -> None:
        if hasattr(self.trainer.model, "on_keyboard_interrupt"):
            self.trainer.model.on_keyboard_interrupt(self.trainer)

        if hasattr(self.trainer.criterion, "on_keyboard_interrupt"):
            self.trainer.criterion.on_keyboard_interrupt(self.trainer)

        if hasattr(self.trainer.optimizer, "on_keyboard_interrupt"):
            self.trainer.optimizer.on_keyboard_interrupt(self.trainer)
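The hooks are duck-typed: any `model`, `criterion`, or `optimizer` that defines a method with a matching name gets called with the trainer. A minimal sketch, not from the diff (the `total_steps_done` trainer attribute is hypothetical):

    from torch import nn

    class MyModel(nn.Module):
        def on_epoch_start(self, trainer):
            # runs at every epoch start via TrainerCallback.on_epoch_start()
            print(f" > epoch start, steps so far: {trainer.total_steps_done}")  # hypothetical attribute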

@ -1,53 +1,8 @@
# edited from https://github.com/fastai/imagenet-fast/blob/master/imagenet_nv/distributed.py
import math

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.autograd import Variable
from torch.utils.data.sampler import Sampler


class DistributedSampler(Sampler):
    """
    Non shuffling Distributed Sampler
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        super().__init__(dataset)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[: (self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
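A usage sketch, not from the diff, assuming `torch.distributed` is already initialized and `dataset`/`num_epochs` are defined elsewhere:

    from torch.utils.data import DataLoader

    sampler = DistributedSampler(dataset)  # rank and world size are read from torch.distributed
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # kept for API parity; this sampler does not shuffle
        for batch in loader:
            ...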


def reduce_tensor(tensor, num_gpus):
@ -8,10 +8,21 @@ import shutil
import subprocess
import sys
from pathlib import Path
from typing import Dict

import torch


def to_cuda(x: torch.Tensor) -> torch.Tensor:
    if x is None:
        return None
    if torch.is_tensor(x):
        x = x.contiguous()
        if torch.cuda.is_available():
            x = x.cuda(non_blocking=True)
    return x


def get_cuda():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

@ -47,13 +58,10 @@ def get_commit_hash():
    return commit


def create_experiment_folder(root_path, model_name, debug):
def create_experiment_folder(root_path, model_name):
    """Create a folder with the current date and time"""
    date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
    if debug:
        commit_hash = "debug"
    else:
        commit_hash = get_commit_hash()
    commit_hash = get_commit_hash()
    output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
    os.makedirs(output_folder, exist_ok=True)
    print(" > Experiment folder: {}".format(output_folder))

@ -126,6 +134,22 @@ def set_init_dict(model_dict, checkpoint_state, c):
    return model_dict

def format_aux_input(def_args: Dict, kwargs: Dict) -> Dict:
    """Format kwargs to handle auxiliary inputs to models.

    Args:
        def_args (Dict): A dictionary of argument names and their default values if not defined in `kwargs`.
        kwargs (Dict): A `dict` or `kwargs` that includes auxiliary inputs to the model.

    Returns:
        Dict: arguments with formatted auxiliary inputs.
    """
    for name in def_args:
        if name not in kwargs:
            kwargs[name] = def_args[name]  # fill in the default for the missing argument
    return kwargs
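Example of the default-filling behavior (illustrative key names, not from the diff):

    defaults = {"d_vectors": None, "speaker_ids": None}
    aux = format_aux_input(defaults, {"speaker_ids": [0]})
    # -> {"speaker_ids": [0], "d_vectors": None}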


class KeepAverage:
    def __init__(self):
        self.avg_values = {}


TTS/utils/io.py
@ -1,7 +1,12 @@
import datetime
import glob
import os
import pickle as pickle_tts
from shutil import copyfile

import torch
from coqpit import Coqpit


class RenamingUnpickler(pickle_tts.Unpickler):
    """Overload default pickler to solve module renaming problem"""

@ -41,3 +46,119 @@ def copy_model_files(config, out_path, new_fields):
        config.audio.stats_path,
        copy_stats_path,
    )


def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    try:
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state


def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    if isinstance(optimizer, list):
        optimizer_state = [optim.state_dict() for optim in optimizer]
    else:
        optimizer_state = optimizer.state_dict() if optimizer is not None else None

    if isinstance(scaler, list):
        scaler_state = [s.state_dict() for s in scaler]
    else:
        scaler_state = scaler.state_dict() if scaler is not None else None

    if isinstance(config, Coqpit):
        config = config.to_dict()

    state = {
        "config": config,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    state.update(kwargs)
    torch.save(state, output_path)


def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    **kwargs,
):
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print("\n > CHECKPOINT : {}".format(checkpoint_path))
    save_model(
        config,
        model,
        optimizer,
        scaler,
        current_step,
        epoch,
        checkpoint_path,
        **kwargs,
    )


def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    if current_loss < best_loss:
        best_model_name = f"best_model_{current_step}.pth.tar"
        checkpoint_path = os.path.join(out_path, best_model_name)
        print(" > BEST MODEL : {}".format(checkpoint_path))
        save_model(
            config,
            model,
            optimizer,
            scaler,
            current_step,
            epoch,
            checkpoint_path,
            model_loss=current_loss,
            **kwargs,
        )
        # only delete previous if current is saved successfully
        if not keep_all_best or (current_step < keep_after):
            model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar"))
            for model_name in model_names:
                if os.path.basename(model_name) == best_model_name:
                    continue
                os.remove(model_name)
        # create symlink to best model for convenience
        link_name = "best_model.pth.tar"
        link_path = os.path.join(out_path, link_name)
        if os.path.islink(link_path) or os.path.isfile(link_path):
            os.remove(link_path)
        os.symlink(best_model_name, os.path.join(out_path, link_name))
        best_loss = current_loss
    return best_loss
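A sketch of the intended call pattern inside an evaluation loop, not from the diff (the surrounding names and `evaluate` step are assumed):

    best_loss = float("inf")
    for epoch in range(config.epochs):
        eval_loss = evaluate(model)  # hypothetical evaluation step
        best_loss = save_best_model(
            eval_loss, best_loss, config, model, optimizer, scaler,
            global_step, epoch, out_path, keep_all_best=False,
        )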

@ -0,0 +1,2 @@
from TTS.utils.logging.console_logger import ConsoleLogger
from TTS.utils.logging.tensorboard_logger import TensorboardLogger
@ -68,11 +68,10 @@ class ConsoleLogger:
        print(log_text, flush=True)

    def print_eval_start(self):
        print(f"{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")
        print(f"\n{tcolors.BOLD} > EVALUATION {tcolors.ENDC}\n")

    def print_eval_step(self, step, loss_dict, avg_loss_dict):
        indent = " | > "
        print()
        log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
        for key, value in loss_dict.items():
            # print the avg value if given

@ -84,7 +83,7 @@ class ConsoleLogger:

    def print_epoch_end(self, epoch, avg_loss_dict):
        indent = " | > "
        log_text = " {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC)
        log_text = "\n {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC)
        for key, value in avg_loss_dict.items():
            # print the avg value if given
            color = ""

@ -34,12 +34,14 @@ class TensorboardLogger(object):

    def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
        for key, value in audios.items():
            if value.dtype == "float16":
                value = value.astype("float32")
            try:
                self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
            except RuntimeError:
                traceback.print_exc()

    def tb_train_iter_stats(self, step, stats):
    def tb_train_step_stats(self, step, stats):
        self.dict_to_tb_scalar(f"{self.model_name}_TrainIterStats", stats, step)

    def tb_train_epoch_stats(self, step, stats):

@ -3,7 +3,7 @@ import json
import os
import zipfile
from pathlib import Path
from shutil import copyfile
from shutil import copyfile, rmtree

import gdown
import requests

@ -83,7 +83,7 @@ class ModelManager(object):
            'type/language/dataset/model'
            e.g. 'tts_model/en/ljspeech/tacotron'

        Every model must have the following files
        Every model must have the following files:
            - *.pth.tar : pytorch model checkpoint file.
            - config.json : model config file.
            - scale_stats.npy (if exist): scale values for preprocessing.

@ -101,11 +101,7 @@ class ModelManager(object):
        output_path = os.path.join(self.output_prefix, model_full_name)
        output_model_path = os.path.join(output_path, "model_file.pth.tar")
        output_config_path = os.path.join(output_path, "config.json")
        # NOTE : band-aid for removing phoneme support
        # if "needs_phonemizer" in model_item and model_item["needs_phonemizer"]:
        #     raise RuntimeError(
        #         " [!] Use 🐸TTS <= v0.0.13 for this model. Current version does not support phoneme based models."
        #     )

        if os.path.exists(output_path):
            print(f" > {model_name} is already downloaded.")
        else:

@ -116,7 +112,6 @@ class ModelManager(object):
            # download files to the output path
            if self._check_dict_key(model_item, "github_rls_url"):
                # download from github release
                # TODO: pass output_path
                self._download_zip_file(model_item["github_rls_url"], output_path)
            else:
                # download from gdrive

@ -137,7 +132,7 @@ class ModelManager(object):
            # set scale stats path in config.json
            config_path = output_config_path
            config = load_config(config_path)
            config.external_speaker_embedding_file = output_speakers_path
            config.d_vector_file = output_speakers_path
            config.save_json(config_path)
        return output_model_path, output_config_path, model_item

@ -146,15 +141,20 @@ class ModelManager(object):
        gdown.download(f"{self.url_prefix}{gdrive_idx}", output=output, quiet=False)

    @staticmethod
    def _download_zip_file(file_url, output):
    def _download_zip_file(file_url, output_folder):
        """Download the github releases"""
        # download the file
        r = requests.get(file_url)
        # extract the file
        with zipfile.ZipFile(io.BytesIO(r.content)) as z:
            z.extractall(output)
            z.extractall(output_folder)
        # move the files to the outer path
        for file_path in z.namelist()[1:]:
            src_path = os.path.join(output, file_path)
            dst_path = os.path.join(output, os.path.basename(file_path))
            src_path = os.path.join(output_folder, file_path)
            dst_path = os.path.join(output_folder, os.path.basename(file_path))
            copyfile(src_path, dst_path)
        # remove the extracted folder
        rmtree(os.path.join(output_folder, z.namelist()[0]))

    @staticmethod
    def _check_dict_key(my_dict, key):
@ -1,4 +1,4 @@
# from https://github.com/LiyuanLucasLiu/RAdam
# modified from https://github.com/LiyuanLucasLiu/RAdam

import math
@ -6,7 +6,7 @@ import pysbd
import torch

from TTS.config import load_config
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.models import setup_model as setup_tts_model
from TTS.tts.utils.speakers import SpeakerManager

# pylint: disable=unused-wildcard-import

@ -14,7 +14,8 @@ from TTS.tts.utils.speakers import SpeakerManager
from TTS.tts.utils.synthesis import synthesis, trim_silence
from TTS.tts.utils.text import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input, setup_generator
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input


class Synthesizer(object):

@ -63,7 +64,7 @@ class Synthesizer(object):
        self.speaker_manager = None
        self.num_speakers = 0
        self.tts_speakers = {}
        self.speaker_embedding_dim = 0
        self.d_vector_dim = 0
        self.seg = self._get_segmenter("en")
        self.use_cuda = use_cuda

@ -98,9 +99,9 @@ class Synthesizer(object):
        self.speaker_manager = SpeakerManager(
            encoder_model_path=self.encoder_checkpoint, encoder_config_path=self.encoder_config
        )
        self.speaker_manager.load_x_vectors_file(self.tts_config.get("external_speaker_embedding_file", speaker_file))
        self.speaker_manager.load_d_vectors_file(self.tts_config.get("d_vector_file", speaker_file))
        self.num_speakers = self.speaker_manager.num_speakers
        self.speaker_embedding_dim = self.speaker_manager.x_vector_dim
        self.d_vector_dim = self.speaker_manager.d_vector_dim

    def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -> None:
        """Load the TTS model.

@ -127,16 +128,11 @@ class Synthesizer(object):

        if self.tts_config.use_speaker_embedding is True:
            self.tts_speakers_file = (
                self.tts_speakers_file if self.tts_speakers_file else self.tts_config["external_speaker_embedding_file"]
                self.tts_speakers_file if self.tts_speakers_file else self.tts_config["d_vector_file"]
            )
            self._load_speakers(self.tts_speakers_file)
            self.tts_config["d_vector_file"] = self.tts_speakers_file

        self.tts_model = setup_model(
            self.input_size,
            num_speakers=self.num_speakers,
            c=self.tts_config,
            speaker_embedding_dim=self.speaker_embedding_dim,
        )
        self.tts_model = setup_tts_model(config=self.tts_config)
        self.tts_model.load_checkpoint(self.tts_config, tts_checkpoint, eval=True)
        if use_cuda:
            self.tts_model.cuda()

@ -151,7 +147,7 @@ class Synthesizer(object):
        """
        self.vocoder_config = load_config(model_config)
        self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config.audio)
        self.vocoder_model = setup_generator(self.vocoder_config)
        self.vocoder_model = setup_vocoder_model(self.vocoder_config)
        self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
        if use_cuda:
            self.vocoder_model.cuda()

@ -197,9 +193,9 @@ class Synthesizer(object):
        print(sens)

        if self.tts_speakers_file:
            # get the speaker embedding from the saved x_vectors.
            # get the speaker embedding from the saved d_vectors.
            if speaker_idx and isinstance(speaker_idx, str):
                speaker_embedding = self.speaker_manager.get_x_vectors_by_speaker(speaker_idx)[0]
                speaker_embedding = self.tts_model.speaker_manager.get_d_vectors_by_speaker(speaker_idx)[0]
            elif not speaker_idx and not speaker_wav:
                raise ValueError(
                    " [!] Looks like you use a multi-speaker model. "

@ -214,15 +210,15 @@ class Synthesizer(object):
                    "Define path for speaker.json if it is a multi-speaker model or remove defined speaker idx. "
                )

            # compute a new x_vector from the given clip.
            # compute a new d_vector from the given clip.
            if speaker_wav is not None:
                speaker_embedding = self.speaker_manager.compute_x_vector_from_clip(speaker_wav)
                speaker_embedding = self.speaker_manager.compute_d_vector_from_clip(speaker_wav)

        use_gl = self.vocoder_model is None

        for sen in sens:
            # synthesize voice
            waveform, _, _, mel_postnet_spec, _, _ = synthesis(
            outputs = synthesis(
                model=self.tts_model,
                text=sen,
                CONFIG=self.tts_config,

@ -230,11 +226,12 @@ class Synthesizer(object):
                ap=self.ap,
                speaker_id=None,
                style_wav=style_wav,
                truncated=False,
                enable_eos_bos_chars=self.tts_config.enable_eos_bos_chars,
                use_griffin_lim=use_gl,
                speaker_embedding=speaker_embedding,
                d_vector=speaker_embedding,
            )
            waveform = outputs["wav"]
            mel_postnet_spec = outputs["model_outputs"]
            if not use_gl:
                # denormalize tts output based on tts audio config
                mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

@ -0,0 +1,69 @@
import importlib
from typing import Dict

import torch

from TTS.utils.training import NoamLR


def is_apex_available():
    return importlib.util.find_spec("apex") is not None


def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        raise RuntimeError(
            f" [!] {num_gpus} active GPUs. Define the target GPU by `CUDA_VISIBLE_DEVICES`. For multi-gpu training use `TTS/bin/distribute.py`."
        )
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.manual_seed(54321)
    use_cuda = torch.cuda.is_available()
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus
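Typical single-GPU entry point usage, as a sketch (not from the diff):

    use_cuda, num_gpus = setup_torch_training_env(cudnn_enable=True, cudnn_benchmark=False)
    device = "cuda" if use_cuda else "cpu"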


def get_scheduler(
    lr_scheduler: str, lr_scheduler_params: Dict, optimizer: torch.optim.Optimizer
) -> torch.optim.lr_scheduler._LRScheduler:  # pylint: disable=protected-access
    """Find, initialize and return a scheduler.

    Args:
        lr_scheduler (str): Scheduler name.
        lr_scheduler_params (Dict): Scheduler parameters.
        optimizer (torch.optim.Optimizer): Optimizer to pass to the scheduler.

    Returns:
        torch.optim.lr_scheduler._LRScheduler: Functional scheduler.
    """
    if lr_scheduler is None:
        return None
    if lr_scheduler.lower() == "noamlr":
        scheduler = NoamLR
    else:
        scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler)
    return scheduler(optimizer, **lr_scheduler_params)


def get_optimizer(
    optimizer_name: str, optimizer_params: dict, lr: float, model: torch.nn.Module
) -> torch.optim.Optimizer:
    """Find, initialize and return an optimizer.

    Args:
        optimizer_name (str): Optimizer name.
        optimizer_params (dict): Optimizer parameters.
        lr (float): Initial learning rate.
        model (torch.nn.Module): Model to pass to the optimizer.

    Returns:
        torch.optim.Optimizer: Functional optimizer.
    """
    if optimizer_name.lower() == "radam":
        module = importlib.import_module("TTS.utils.radam")
        optimizer = getattr(module, "RAdam")
    else:
        optimizer = getattr(torch.optim, optimizer_name)
    return optimizer(model.parameters(), lr=lr, **optimizer_params)
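A sketch of resolving both from config strings (not from the diff; stand-in model, parameter values assumed):

    import torch

    model = torch.nn.Linear(10, 10)
    optimizer = get_optimizer("AdamW", {"betas": [0.8, 0.99], "weight_decay": 0.0}, lr=1e-3, model=model)
    scheduler = get_scheduler("ExponentialLR", {"gamma": 0.999, "last_epoch": -1}, optimizer)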

@ -2,17 +2,6 @@ import numpy as np
import torch


def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
    torch.backends.cudnn.enabled = cudnn_enable
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.manual_seed(54321)
    use_cuda = torch.cuda.is_available()
    num_gpus = torch.cuda.device_count()
    print(" > Using CUDA: ", use_cuda)
    print(" > Number of GPUs: ", num_gpus)
    return use_cuda, num_gpus


def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    r"""Check model gradient against unexpected jumps and failures"""
    skip_flag = False

@ -41,46 +30,6 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
    return grad_norm, skip_flag


def lr_decay(init_lr, global_step, warmup_steps):
    r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py"""
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
    return lr


def adam_weight_decay(optimizer):
    """
    Custom weight decay operation, not affecting grad values.
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            current_lr = group["lr"]
            weight_decay = group["weight_decay"]
            factor = -weight_decay * group["lr"]
            param.data = param.data.add(param.data, alpha=factor)
    return optimizer, current_lr


# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
    """
    Skip biases, BatchNorm parameters, rnns,
    and attention projection layer v.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]


# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):

@ -107,3 +56,31 @@ def gradual_training_scheduler(global_step, config):
        if global_step * num_gpus >= values[0]:
            new_values = values
    return new_values[1], new_values[2]


def lr_decay(init_lr, global_step, warmup_steps):
    r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py
    It is only being used by the Speaker Encoder trainer."""
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
    return lr
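Numerically, the schedule peaks at `init_lr` when `step == warmup_steps` (with `step = global_step + 1`) and decays as `step ** -0.5` afterwards; a worked example with assumed values:

    # lr_decay(1e-3, 3999, warmup_steps=4000)  -> 1e-3 (peak of the warmup ramp)
    # lr_decay(1e-3, 15999, warmup_steps=4000) -> 1e-3 * sqrt(4000 / 16000) = 5e-4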


# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
    """
    Skip biases, BatchNorm parameters, rnns,
    and attention projection layer v.
    """
    decay = []
    no_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
            no_decay.append(param)
        else:
            decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]

@ -14,7 +14,7 @@ class FullbandMelganConfig(BaseGANVocoderConfig):

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `melgan`.
            Model name used for selecting the right model at initialization. Defaults to `fullband_melgan`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `melgan_multiscale_discriminator`.
        discriminator_model_params (dict): The discriminator model parameters. Defaults to

@ -62,7 +62,7 @@ class FullbandMelganConfig(BaseGANVocoderConfig):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
    """

    model: str = "melgan"
    model: str = "fullband_melgan"

    # Model specific params
    discriminator_model: str = "melgan_multiscale_discriminator"
@ -14,7 +14,7 @@ class MultibandMelganConfig(BaseGANVocoderConfig):

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `melgan`.
            Model name used for selecting the right model at initialization. Defaults to `multiband_melgan`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `melgan_multiscale_discriminator`.
        discriminator_model_params (dict): The discriminator model parameters. Defaults to

@ -9,7 +9,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig):

    Args:
        model (str):
            Model name used for selecting the right configuration at initialization. Defaults to `parallel_wavegan`.
            Model name used for selecting the right configuration at initialization. Defaults to `gan`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `parallel_wavegan_discriminator`.
        discriminator_model_params (dict): The discriminator model kwargs. Defaults to
@ -34,6 +34,10 @@ class BaseVocoderConfig(BaseTrainingConfig):
            Number of training epochs. Defaults to 10000.
        wd (float):
            Weight decay.
        optimizer (torch.optim.Optimizer):
            Optimizer used for the training. Defaults to `AdamW`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
    """

    audio: BaseAudioConfig = field(default_factory=BaseAudioConfig)

@ -50,6 +54,8 @@ class BaseVocoderConfig(BaseTrainingConfig):
    # OPTIMIZER
    epochs: int = 10000  # total number of epochs to train.
    wd: float = 0.0  # Weight decay weight.
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})


@dataclass

@ -96,20 +102,13 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
            }`
        target_loss (str):
            Target loss name that defines the quality of the model. Defaults to `avg_G_loss`.
        gen_clip_grad (float):
            Gradient clipping threshold for the generator model. Any value less than 0 disables clipping.
            Defaults to -1.
        disc_clip_grad (float):
            Gradient clipping threshold for the discriminator model. Any value less than 0 disables clipping.
            Defaults to -1.
        grad_clip (list):
            A list of gradient clipping thresholds for each optimizer. Any value less than 0 disables clipping.
            Defaults to [5, 5].
        lr_gen (float):
            Generator model initial learning rate. Defaults to 0.0002.
        lr_disc (float):
            Discriminator model initial learning rate. Defaults to 0.0002.
        optimizer (torch.optim.Optimizer):
            Optimizer used for the training. Defaults to `AdamW`.
        optimizer_params (dict):
            Optimizer kwargs. Defaults to `{"betas": [0.8, 0.99], "weight_decay": 0.0}`
        lr_scheduler_gen (torch.optim.Scheduler):
            Learning rate scheduler for the generator. Defaults to `ExponentialLR`.
        lr_scheduler_gen_params (dict):

@ -127,6 +126,8 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
            Enabling it results in slower iterations but faster convergence in some cases. Defaults to False.
    """

    model: str = "gan"

    # LOSS PARAMETERS
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = True

@ -164,15 +165,12 @@ class BaseGANVocoderConfig(BaseVocoderConfig):
        }
    )

    target_loss: str = "avg_G_loss"  # loss value to pick the best model to save after each epoch
    target_loss: str = "loss_0"  # loss value to pick the best model to save after each epoch

    # optimizer
    gen_clip_grad: float = -1  # Generator gradient clipping threshold. Apply gradient clipping if > 0
    disc_clip_grad: float = -1  # Discriminator gradient clipping threshold.
    grad_clip: float = field(default_factory=lambda: [5, 5])
    lr_gen: float = 0.0002  # Initial learning rate.
    lr_disc: float = 0.0002  # Initial learning rate.
    optimizer: str = "AdamW"
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "weight_decay": 0.0})
    lr_scheduler_gen: str = "ExponentialLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = "ExponentialLR"  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
@ -0,0 +1,160 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseGANVocoderConfig


@dataclass
class UnivnetConfig(BaseGANVocoderConfig):
    """Defines parameters for UnivNet vocoder.

    Example:

        >>> from TTS.vocoder.configs import UnivNetConfig
        >>> config = UnivNetConfig()

    Args:
        model (str):
            Model name used for selecting the right model at initialization. Defaults to `UnivNet`.
        discriminator_model (str): One of the discriminators from `TTS.vocoder.models.*_discriminator`. Defaults to
            `UnivNet_discriminator`.
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `UnivNet_generator`.
        generator_model_params (dict): Parameters of the generator model. Defaults to
            `
            {
                "use_mel": True,
                "sample_rate": 22050,
                "n_fft": 1024,
                "hop_length": 256,
                "win_length": 1024,
                "n_mels": 80,
                "mel_fmin": 0.0,
                "mel_fmax": None,
            }
            `
        batch_size (int):
            Batch size used at training. Larger values use more memory. Defaults to 32.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 8192.
        pad_short (int):
            Additional padding applied to the audio samples shorter than `seq_len`. Defaults to 0.
        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to True.
        use_cache (bool):
            enable / disable in memory caching of the computed features. It can cause OOM error if the system RAM is
            not large enough. Defaults to True.
        use_stft_loss (bool):
            enable / disable use of STFT loss originally used by ParallelWaveGAN model. Defaults to True.
        use_subband_stft (bool):
            enable / disable use of subband loss computation originally used by MultiBandMelgan model. Defaults to True.
        use_mse_gan_loss (bool):
            enable / disable using Mean Square Error GAN loss. Defaults to True.
        use_hinge_gan_loss (bool):
            enable / disable using Hinge GAN loss. You should choose either Hinge or MSE loss for training GAN models.
            Defaults to False.
        use_feat_match_loss (bool):
            enable / disable using Feature Matching loss originally used by MelGAN model. Defaults to True.
        use_l1_spec_loss (bool):
            enable / disable using L1 spectrogram loss originally used by univnet model. Defaults to False.
        stft_loss_params (dict):
            STFT loss parameters. Default to
            `{
                "n_ffts": [1024, 2048, 512],
                "hop_lengths": [120, 240, 50],
                "win_lengths": [600, 1200, 240]
            }`
        l1_spec_loss_params (dict):
            L1 spectrogram loss parameters. Default to
            `{
                "use_mel": True,
                "sample_rate": 22050,
                "n_fft": 1024,
                "hop_length": 256,
                "win_length": 1024,
                "n_mels": 80,
                "mel_fmin": 0.0,
                "mel_fmax": None,
            }`
        stft_loss_weight (float): STFT loss weight that multiplies the computed loss before summing up the total
            model loss. Defaults to 0.5.
        subband_stft_loss_weight (float):
            Subband STFT loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        mse_G_loss_weight (float):
            MSE generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 2.5.
        hinge_G_loss_weight (float):
            Hinge generator loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
        feat_match_loss_weight (float):
            Feature matching loss weight that multiplies the computed loss before summing up the total loss. Defaults to 108.
        l1_spec_loss_weight (float):
            L1 spectrogram loss weight that multiplies the computed loss before summing up the total loss. Defaults to 0.
    """

    model: str = "univnet"
    batch_size: int = 32
    # model specific params
    discriminator_model: str = "univnet_discriminator"
    generator_model: str = "univnet_generator"
    generator_model_params: dict = field(
        default_factory=lambda: {
            "in_channels": 64,
            "out_channels": 1,
            "hidden_channels": 32,
            "cond_channels": 80,
            "upsample_factors": [8, 8, 4],
            "lvc_layers_each_block": 4,
            "lvc_kernel_size": 3,
            "kpnet_hidden_channels": 64,
            "kpnet_conv_size": 3,
            "dropout": 0.0,
        }
    )

    # LOSS PARAMETERS - overrides
    use_stft_loss: bool = True
    use_subband_stft_loss: bool = False
    use_mse_gan_loss: bool = True
    use_hinge_gan_loss: bool = False
    use_feat_match_loss: bool = False  # requires MelGAN Discriminators (MelGAN and univnet)
    use_l1_spec_loss: bool = False

    # loss weights - overrides
    stft_loss_weight: float = 2.5
    stft_loss_params: dict = field(
        default_factory=lambda: {
            "n_ffts": [1024, 2048, 512],
            "hop_lengths": [120, 240, 50],
            "win_lengths": [600, 1200, 240],
        }
    )
    subband_stft_loss_weight: float = 0
    mse_G_loss_weight: float = 1
    hinge_G_loss_weight: float = 0
    feat_match_loss_weight: float = 0
    l1_spec_loss_weight: float = 0
    l1_spec_loss_params: dict = field(
        default_factory=lambda: {
            "use_mel": True,
            "sample_rate": 22050,
            "n_fft": 1024,
            "hop_length": 256,
            "win_length": 1024,
            "n_mels": 80,
            "mel_fmin": 0.0,
            "mel_fmax": None,
        }
    )

    # optimizer parameters
    lr_gen: float = 1e-4  # Initial learning rate.
    lr_disc: float = 1e-4  # Initial learning rate.
    lr_scheduler_gen: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    lr_scheduler_disc: str = None  # one of the schedulers from https://pytorch.org/docs/stable/optim.html
    # lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999, "last_epoch": -1})
    optimizer_params: dict = field(default_factory=lambda: {"betas": [0.5, 0.9], "weight_decay": 0.0})
    steps_to_start_discriminator: int = 200000

    def __post_init__(self):
        super().__post_init__()
        self.generator_model_params["cond_channels"] = self.audio.num_mels
|
|
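For orientation, here is a minimal usage sketch of the config above; the import path is assumed from the repo layout, and the field names are the ones shown in this diff:

```python
# Sketch: instantiate UnivnetConfig and re-balance a loss term (import path assumed).
from TTS.vocoder.configs.univnet_config import UnivnetConfig

config = UnivnetConfig(batch_size=16)  # override the default of 32
config.stft_loss_weight = 1.0          # down-weight the multi-resolution STFT term
# __post_init__ keeps the generator conditioning size in sync with the audio config:
assert config.generator_model_params["cond_channels"] == config.audio.num_mels
```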
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavegrad import WavegradArgs


@dataclass
@@ -16,19 +17,7 @@ class WavegradConfig(BaseVocoderConfig):
            Model name used for selecting the right model at initialization. Defaults to `wavegrad`.
        generator_model (str): One of the generators from `TTS.vocoder.models.*`. Every other non-GAN vocoder model is
            considered as a generator too. Defaults to `wavegrad`.
        model_params (dict):
            WaveGrad kwargs. Defaults to
            `
            {
                "use_weight_norm": True,
                "y_conv_channels": 32,
                "x_conv_channels": 768,
                "ublock_out_channels": [512, 512, 256, 128, 128],
                "dblock_out_channels": [128, 128, 256, 512],
                "upsample_factors": [4, 4, 4, 2, 2],
                "upsample_dilations": [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
            }
            `
        model_params (WavegradArgs): Model parameters. Check `WavegradArgs` for default values.
        target_loss (str):
            Target loss name that defines the quality of the model. Defaults to `avg_wavegrad_loss`.
        epochs (int):
@@ -70,18 +59,8 @@ class WavegradConfig(BaseVocoderConfig):
    model: str = "wavegrad"
    # Model specific params
    generator_model: str = "wavegrad"
    model_params: dict = field(
        default_factory=lambda: {
            "use_weight_norm": True,
            "y_conv_channels": 32,
            "x_conv_channels": 768,
            "ublock_out_channels": [512, 512, 256, 128, 128],
            "dblock_out_channels": [128, 128, 256, 512],
            "upsample_factors": [4, 4, 4, 2, 2],
            "upsample_dilations": [[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]],
        }
    )
    target_loss: str = "avg_wavegrad_loss"  # loss value to pick the best model to save after each epoch
    model_params: WavegradArgs = field(default_factory=WavegradArgs)
    target_loss: str = "loss"  # loss value to pick the best model to save after each epoch

    # Training - overrides
    epochs: int = 10000
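With this change, WaveGrad model parameters move from a free-form dict into the typed `WavegradArgs` dataclass. A sketch of overriding a nested field under that assumption:

```python
# Sketch: typed nested overrides instead of raw dict keys (field name assumed
# to carry over from the old dict shown above).
from TTS.vocoder.configs.wavegrad_config import WavegradConfig
from TTS.vocoder.models.wavegrad import WavegradArgs

config = WavegradConfig(model_params=WavegradArgs(use_weight_norm=False))
```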
@@ -1,6 +1,7 @@
from dataclasses import dataclass, field

from TTS.vocoder.configs.shared_configs import BaseVocoderConfig
from TTS.vocoder.models.wavernn import WavernnArgs


@dataclass
@@ -47,9 +48,7 @@ class WavernnConfig(BaseVocoderConfig):
            Batch size used at training. Larger values use more memory. Defaults to 256.
        seq_len (int):
            Audio segment length used at training. Larger values use more memory. Defaults to 1280.
        padding (int):
            Padding applied to the input feature frames against the convolution layers of the feature network.
            Defaults to 2.

        use_noise_augment (bool):
            enable / disable random noise added to the input waveform. The noise is added after computing the
            features. Defaults to True.
@@ -60,7 +59,7 @@ class WavernnConfig(BaseVocoderConfig):
            enable / disable mixed precision training. Default is True.
        eval_split_size (int):
            Number of samples used for evaluation. Defaults to 50.
        test_every_epoch (int):
        num_epochs_before_test (int):
            Number of epochs waited to run the next evaluation. Since inference takes some time, it is better to
            wait some number of epochs not to waste training time. Defaults to 10.
        grad_clip (float):
@@ -76,21 +75,8 @@ class WavernnConfig(BaseVocoderConfig):
    model: str = "wavernn"

    # Model specific params
    mode: str = "mold"  # mold [string], gauss [string], bits [int]
    mulaw: bool = True  # apply mulaw if mode is bits
    generator_model: str = "WaveRNN"
    wavernn_model_params: dict = field(
        default_factory=lambda: {
            "rnn_dims": 512,
            "fc_dims": 512,
            "compute_dims": 128,
            "res_out_dims": 128,
            "num_res_blocks": 10,
            "use_aux_net": True,
            "use_upsample_net": True,
            "upsample_factors": [4, 8, 8],  # this needs to correctly factorise hop_length
        }
    )
    model_params: WavernnArgs = field(default_factory=WavernnArgs)
    target_loss: str = "loss"

    # Inference
    batched: bool = True
@@ -101,12 +87,13 @@ class WavernnConfig(BaseVocoderConfig):
    epochs: int = 10000
    batch_size: int = 256
    seq_len: int = 1280
    padding: int = 2
    use_noise_augment: bool = False
    use_cache: bool = True
    mixed_precision: bool = True
    eval_split_size: int = 50
    test_every_epochs: int = 10  # number of epochs to wait until the next test run (synthesizing a full audio clip).
    num_epochs_before_test: int = (
        10  # number of epochs to wait until the next test run (synthesizing a full audio clip).
    )

    # optimizer overrides
    grad_clip: float = 4.0
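Since `test_every_epochs` is renamed to `num_epochs_before_test`, existing user configs need a one-line update; a sketch:

```python
# Sketch: the field name after this diff (import path assumed).
from TTS.vocoder.configs.wavernn_config import WavernnConfig

config = WavernnConfig()
config.num_epochs_before_test = 5  # previously: config.test_every_epochs = 5
```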
@@ -0,0 +1,57 @@
from typing import List

from coqpit import Coqpit
from torch.utils.data import Dataset

from TTS.utils.audio import AudioProcessor
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset


def setup_dataset(config: Coqpit, ap: AudioProcessor, is_eval: bool, data_items: List, verbose: bool) -> Dataset:
    if config.model.lower() in "gan":
        dataset = GANDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            return_pairs=config.diff_samples_for_G_and_D if "diff_samples_for_G_and_D" in config else False,
            is_training=not is_eval,
            return_segments=not is_eval,
            use_noise_augment=config.use_noise_augment,
            use_cache=config.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
    elif config.model.lower() == "wavegrad":
        dataset = WaveGradDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad_short=config.pad_short,
            conv_pad=config.conv_pad,
            is_training=not is_eval,
            return_segments=True,
            use_noise_augment=False,
            use_cache=config.use_cache,
            verbose=verbose,
        )
    elif config.model.lower() == "wavernn":
        dataset = WaveRNNDataset(
            ap=ap,
            items=data_items,
            seq_len=config.seq_len,
            hop_len=ap.hop_length,
            pad=config.model_params.pad,
            mode=config.model_params.mode,
            mulaw=config.model_params.mulaw,
            is_training=not is_eval,
            verbose=verbose,
        )
    else:
        raise ValueError(f" [!] Dataset for model {config.model.lower()} cannot be found.")
    return dataset
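A hypothetical call site for the new factory; in practice `data_items` comes from the dataset loading utilities in the training script, which are not part of this diff:

```python
# Sketch: building a training dataset through setup_dataset (paths assumed).
from TTS.utils.audio import AudioProcessor
from TTS.vocoder.configs.wavernn_config import WavernnConfig
from TTS.vocoder.datasets import setup_dataset

config = WavernnConfig()
ap = AudioProcessor(**config.audio.to_dict())
data_items = ["/data/wavs/sample_0001.wav"]  # stub; real items come from the loader
train_set = setup_dataset(config, ap, is_eval=False, data_items=data_items, verbose=True)
```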
@@ -3,10 +3,21 @@ import os
from pathlib import Path

import numpy as np
from coqpit import Coqpit
from tqdm import tqdm

from TTS.utils.audio import AudioProcessor


def preprocess_wav_files(out_path, config, ap):

def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor):
    """Process wav and compute mel and quantized wave signal.
    It is mainly used by WaveRNN dataloader.

    Args:
        out_path (str): Parent folder path to save the files.
        config (Coqpit): Model config.
        ap (AudioProcessor): Audio processor.
    """
    os.makedirs(os.path.join(out_path, "quant"), exist_ok=True)
    os.makedirs(os.path.join(out_path, "mel"), exist_ok=True)
    wav_files = find_wav_files(config.data_path)

@@ -18,7 +29,9 @@ def preprocess_wav_files(out_path, config, ap):
        mel = ap.melspectrogram(y)
        np.save(mel_path, mel)
        if isinstance(config.mode, int):
            quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
            quant = (
                ap.mulaw_encode(y, qc=config.mode) if config.model_params.mulaw else ap.quantize(y, bits=config.mode)
            )
            np.save(quant_path, quant)
@@ -136,4 +136,4 @@ class WaveGradDataset(Dataset):
            mels[idx, :, : mel.shape[1]] = mel
            audios[idx, : audio.shape[0]] = audio

        return mels, audios
        return audios, mels
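Because the collate output order flips from `(mels, audios)` to `(audios, mels)`, any consumer unpacking the batch has to follow suit; a small sketch with dummy tensors:

```python
import torch

# Dummy batch mimicking the collate output after this diff: audios first.
batch = (torch.zeros(4, 22050), torch.zeros(4, 80, 87))  # shapes illustrative
audios, mels = batch  # was: mels, audios = batch
assert audios.shape[0] == mels.shape[0]
```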
@@ -10,16 +10,7 @@ class WaveRNNDataset(Dataset):
    """

    def __init__(
        self,
        ap,
        items,
        seq_len,
        hop_len,
        pad,
        mode,
        mulaw,
        is_training=True,
        verbose=False,
        self, ap, items, seq_len, hop_len, pad, mode, mulaw, is_training=True, verbose=False, return_segments=True
    ):

        super().__init__()
@@ -34,6 +25,7 @@ class WaveRNNDataset(Dataset):
        self.mulaw = mulaw
        self.is_training = is_training
        self.verbose = verbose
        self.return_segments = return_segments

        assert self.seq_len % self.hop_len == 0
@@ -44,6 +36,16 @@ class WaveRNNDataset(Dataset):
        item = self.load_item(index)
        return item

    def load_test_samples(self, num_samples):
        samples = []
        return_segments = self.return_segments
        self.return_segments = False
        for idx in range(num_samples):
            mel, audio, _ = self.load_item(idx)
            samples.append([mel, audio])
        self.return_segments = return_segments
        return samples

    def load_item(self, index):
        """
        load (audio, feat) couple if feature_path is set
@@ -53,7 +55,10 @@ class WaveRNNDataset(Dataset):

        wavpath = self.item_list[index]
        audio = self.ap.load_wav(wavpath)
        min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
        if self.return_segments:
            min_audio_len = 2 * self.seq_len + (2 * self.pad * self.hop_len)
        else:
            min_audio_len = audio.shape[0] + (2 * self.pad * self.hop_len)
        if audio.shape[0] < min_audio_len:
            print(" [!] Instance is too short! : {}".format(wavpath))
            audio = np.pad(audio, [0, min_audio_len - audio.shape[0] + self.hop_len])
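A worked example of the segment-mode minimum length, using the `WavernnConfig` defaults above (`seq_len=1280`, `padding=2`) and an assumed `hop_length` of 256:

```python
# Segment mode: the clip must cover two training segments plus conv padding.
seq_len, pad, hop_len = 1280, 2, 256
min_audio_len = 2 * seq_len + (2 * pad * hop_len)
print(min_audio_len)  # 3584 samples; shorter clips get zero-padded with a warning
```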
@@ -1,83 +1,11 @@
import librosa
from typing import Dict, Union

import torch
from torch import nn
from torch.nn import functional as F


class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
    """TODO: Merge this with audio.py"""

    def __init__(
        self,
        n_fft,
        hop_length,
        win_length,
        pad_wav=False,
        window="hann_window",
        sample_rate=None,
        mel_fmin=0,
        mel_fmax=None,
        n_mels=80,
        use_mel=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.pad_wav = pad_wav
        self.sample_rate = sample_rate
        self.mel_fmin = mel_fmin
        self.mel_fmax = mel_fmax
        self.n_mels = n_mels
        self.use_mel = use_mel
        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
        self.mel_basis = None
        if use_mel:
            self._build_mel_basis()

    def __call__(self, x):
        """Compute spectrogram frames by torch based stft.

        Args:
            x (Tensor): input waveform

        Returns:
            Tensor: spectrogram frames.

        Shapes:
            x: [B x T] or [B x 1 x T]
        """
        if x.ndim == 2:
            x = x.unsqueeze(1)
        if self.pad_wav:
            padding = int((self.n_fft - self.hop_length) / 2)
            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
        # B x D x T x 2
        o = torch.stft(
            x.squeeze(1),
            self.n_fft,
            self.hop_length,
            self.win_length,
            self.window,
            center=True,
            pad_mode="reflect",  # compatible with audio.py
            normalized=False,
            onesided=True,
            return_complex=False,
        )
        M = o[:, :, :, 0]
        P = o[:, :, :, 1]
        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
        if self.use_mel:
            S = torch.matmul(self.mel_basis.to(x), S)
        return S

    def _build_mel_basis(self):
        mel_basis = librosa.filters.mel(
            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

from TTS.utils.audio import TorchSTFT
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss

#################################
# GENERATOR LOSSES
@@ -271,7 +199,7 @@ def _apply_D_loss(scores_fake, scores_real, loss_func):
            loss += total_loss
            real_loss += real_loss
            fake_loss += fake_loss
        # normalize loss values with number of scales
        # normalize loss values with number of scales (discriminators)
        loss /= len(scores_fake)
        real_loss /= len(scores_real)
        fake_loss /= len(scores_fake)
@@ -374,7 +302,7 @@ class GeneratorLoss(nn.Module):
            feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
            return_dict["G_feat_match_loss"] = feat_match_loss
            adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
        return_dict["G_loss"] = gen_loss + adv_loss
        return_dict["loss"] = gen_loss + adv_loss
        return_dict["G_gen_loss"] = gen_loss
        return_dict["G_adv_loss"] = adv_loss
        return return_dict
@@ -419,5 +347,22 @@ class DiscriminatorLoss(nn.Module):
            return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
            loss += hinge_D_loss

        return_dict["D_loss"] = loss
        return_dict["loss"] = loss
        return return_dict


class WaveRNNLoss(nn.Module):
    def __init__(self, wave_rnn_mode: Union[str, int]):
        super().__init__()
        if wave_rnn_mode == "mold":
            self.loss_func = discretized_mix_logistic_loss
        elif wave_rnn_mode == "gauss":
            self.loss_func = gaussian_loss
        elif isinstance(wave_rnn_mode, int):
            self.loss_func = torch.nn.CrossEntropyLoss()
        else:
            raise ValueError(" [!] Unknown mode for Wavernn.")

    def forward(self, y_hat, y) -> Dict:
        loss = self.loss_func(y_hat, y)
        return {"loss": loss}
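A minimal sketch of the new `WaveRNNLoss` in its integer (`bits`) mode, where it falls back to cross-entropy; the module path is assumed from the repo layout and the shapes are illustrative:

```python
import torch

from TTS.vocoder.layers.losses import WaveRNNLoss  # module path assumed

loss_fn = WaveRNNLoss(wave_rnn_mode=9)  # 9-bit output -> 2**9 = 512 classes
y_hat = torch.randn(8, 512)             # illustrative logits
y = torch.randint(0, 512, (8,))         # illustrative class targets
print(loss_fn(y_hat, y)["loss"])
```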
@@ -0,0 +1,198 @@
import torch
import torch.nn.functional as F


class KernelPredictor(torch.nn.Module):
    """Kernel predictor for the location-variable convolutions"""

    def __init__(  # pylint: disable=dangerous-default-value
        self,
        cond_channels,
        conv_in_channels,
        conv_out_channels,
        conv_layers,
        conv_kernel_size=3,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
        kpnet_nonlinear_activation="LeakyReLU",
        kpnet_nonlinear_activation_params={"negative_slope": 0.1},
    ):
        """
        Args:
            cond_channels (int): number of channels for the conditioning sequence,
            conv_in_channels (int): number of channels for the input sequence,
            conv_out_channels (int): number of channels for the output sequence,
            conv_layers (int):
            kpnet_
        """
        super().__init__()

        self.conv_in_channels = conv_in_channels
        self.conv_out_channels = conv_out_channels
        self.conv_kernel_size = conv_kernel_size
        self.conv_layers = conv_layers

        l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers
        l_b = conv_out_channels * conv_layers

        padding = (kpnet_conv_size - 1) // 2
        self.input_conv = torch.nn.Sequential(
            torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.residual_conv = torch.nn.Sequential(
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Dropout(kpnet_dropout),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
            torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True),
            getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params),
        )

        self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, padding=padding, bias=True)
        self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, bias=True)

    def forward(self, c):
        """
        Args:
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)
        Returns:
        """
        batch, _, cond_length = c.shape

        c = self.input_conv(c)
        c = c + self.residual_conv(c)
        k = self.kernel_conv(c)
        b = self.bias_conv(c)

        kernels = k.contiguous().view(
            batch, self.conv_layers, self.conv_in_channels, self.conv_out_channels, self.conv_kernel_size, cond_length
        )
        bias = b.contiguous().view(batch, self.conv_layers, self.conv_out_channels, cond_length)
        return kernels, bias


class LVCBlock(torch.nn.Module):
    """the location-variable convolutions"""

    def __init__(
        self,
        in_channels,
        cond_channels,
        upsample_ratio,
        conv_layers=4,
        conv_kernel_size=3,
        cond_hop_length=256,
        kpnet_hidden_channels=64,
        kpnet_conv_size=3,
        kpnet_dropout=0.0,
    ):
        super().__init__()

        self.cond_hop_length = cond_hop_length
        self.conv_layers = conv_layers
        self.conv_kernel_size = conv_kernel_size
        self.convs = torch.nn.ModuleList()

        self.upsample = torch.nn.ConvTranspose1d(
            in_channels,
            in_channels,
            kernel_size=upsample_ratio * 2,
            stride=upsample_ratio,
            padding=upsample_ratio // 2 + upsample_ratio % 2,
            output_padding=upsample_ratio % 2,
        )

        self.kernel_predictor = KernelPredictor(
            cond_channels=cond_channels,
            conv_in_channels=in_channels,
            conv_out_channels=2 * in_channels,
            conv_layers=conv_layers,
            conv_kernel_size=conv_kernel_size,
            kpnet_hidden_channels=kpnet_hidden_channels,
            kpnet_conv_size=kpnet_conv_size,
            kpnet_dropout=kpnet_dropout,
        )

        for i in range(conv_layers):
            padding = (3 ** i) * int((conv_kernel_size - 1) / 2)
            conv = torch.nn.Conv1d(
                in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i
            )

            self.convs.append(conv)

    def forward(self, x, c):
        """forward propagation of the location-variable convolutions.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length)
            c (Tensor): the conditioning sequence (batch, cond_channels, cond_length)

        Returns:
            Tensor: the output sequence (batch, in_channels, in_length)
        """
        in_channels = x.shape[1]
        kernels, bias = self.kernel_predictor(c)

        x = F.leaky_relu(x, 0.2)
        x = self.upsample(x)

        for i in range(self.conv_layers):
            y = F.leaky_relu(x, 0.2)
            y = self.convs[i](y)
            y = F.leaky_relu(y, 0.2)

            k = kernels[:, i, :, :, :, :]
            b = bias[:, i, :, :]
            y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length)
            x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :])
        return x

    @staticmethod
    def location_variable_convolution(x, kernel, bias, dilation, hop_size):
        """perform location-variable convolution operation on the input sequence (x) using the local convolution kernel.
        Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), tested on an NVIDIA V100.
        Args:
            x (Tensor): the input sequence (batch, in_channels, in_length).
            kernel (Tensor): the local convolution kernel (batch, in_channels, out_channels, kernel_size, kernel_length)
            bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length)
            dilation (int): the dilation of convolution.
            hop_size (int): the hop_size of the conditioning sequence.
        Returns:
            (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length).
        """
        batch, _, in_length = x.shape
        batch, _, out_channels, kernel_size, kernel_length = kernel.shape

        assert in_length == (
            kernel_length * hop_size
        ), f"length of (x, kernel) is not matched, {in_length} vs {kernel_length * hop_size}"

        padding = dilation * int((kernel_size - 1) / 2)
        x = F.pad(x, (padding, padding), "constant", 0)  # (batch, in_channels, in_length + 2*padding)
        x = x.unfold(2, hop_size + 2 * padding, hop_size)  # (batch, in_channels, kernel_length, hop_size + 2*padding)

        if hop_size < dilation:
            x = F.pad(x, (0, dilation), "constant", 0)
        x = x.unfold(
            3, dilation, dilation
        )  # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation)
        x = x[:, :, :, :, :hop_size]
        x = x.transpose(3, 4)  # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation)
        x = x.unfold(4, kernel_size, 1)  # (batch, in_channels, kernel_length, dilation, _, kernel_size)

        o = torch.einsum("bildsk,biokl->bolsd", x, kernel)
        o = o + bias.unsqueeze(-1).unsqueeze(-1)
        o = o.contiguous().view(batch, out_channels, -1)
        return o
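The assertion in `location_variable_convolution` ties the waveform length to the conditioning length (exactly `hop_size` samples per frame). A small shape sketch of the first unfold step, with illustrative values:

```python
import torch
import torch.nn.functional as F

batch, in_channels, hop_size, kernel_length = 2, 8, 256, 10
in_length = kernel_length * hop_size  # 2560, or the assert above fires
x = torch.randn(batch, in_channels, in_length)

kernel_size, dilation = 3, 1
padding = dilation * (kernel_size - 1) // 2
x = F.pad(x, (padding, padding))                   # (2, 8, 2562)
x = x.unfold(2, hop_size + 2 * padding, hop_size)  # one window per frame
print(x.shape)  # torch.Size([2, 8, 10, 258])
```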
@@ -0,0 +1,153 @@
import importlib
import re

from coqpit import Coqpit


def to_camel(text):
    text = text.capitalize()
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(config: Coqpit):
    """Load models directly from configuration."""
    print(" > Vocoder Model: {}".format(config.model))
    if "discriminator_model" in config and "generator_model" in config:
        MyModel = importlib.import_module("TTS.vocoder.models.gan")
        MyModel = getattr(MyModel, "GAN")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + config.model.lower())
        if config.model.lower() == "wavernn":
            MyModel = getattr(MyModel, "Wavernn")
        elif config.model.lower() == "gan":
            MyModel = getattr(MyModel, "GAN")
        elif config.model.lower() == "wavegrad":
            MyModel = getattr(MyModel, "Wavegrad")
        else:
            MyModel = getattr(MyModel, to_camel(config.model))
            raise ValueError(f"Model {config.model} does not exist!")
    model = MyModel(config)
    return model


def setup_generator(c):
    """TODO: use config object as arguments"""
    print(" > Generator Model: {}".format(c.generator_model))
    MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.generator_model))
    # this is to preserve the Wavernn class name (instead of WaveRNN)
    if c.generator_model.lower() in "hifigan_generator":
        model = MyModel(in_channels=c.audio["num_mels"], out_channels=1, **c.generator_model_params)
    elif c.generator_model.lower() in "melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model in "melgan_fb_generator":
        raise ValueError("melgan_fb_generator is now fullband_melgan_generator")
    elif c.generator_model.lower() in "multiband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=4,
            proj_kernel=7,
            base_channels=384,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() in "fullband_melgan_generator":
        model = MyModel(
            in_channels=c.audio["num_mels"],
            out_channels=1,
            proj_kernel=7,
            base_channels=512,
            upsample_factors=c.generator_model_params["upsample_factors"],
            res_kernel=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
        )
    elif c.generator_model.lower() in "parallel_wavegan_generator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_res_blocks=c.generator_model_params["num_res_blocks"],
            stacks=c.generator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            aux_channels=c.audio["num_mels"],
            dropout=0.0,
            bias=True,
            use_weight_norm=True,
            upsample_factors=c.generator_model_params["upsample_factors"],
        )
    elif c.generator_model.lower() in "univnet_generator":
        model = MyModel(**c.generator_model_params)
    else:
        raise NotImplementedError(f"Model {c.generator_model} not implemented!")
    return model


def setup_discriminator(c):
    """TODO: use config object as arguments"""
    print(" > Discriminator Model: {}".format(c.discriminator_model))
    if "parallel_wavegan" in c.discriminator_model:
        MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator")
    else:
        MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower())
    MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower()))
    if c.discriminator_model in "hifigan_discriminator":
        model = MyModel()
    if c.discriminator_model in "random_window_discriminator":
        model = MyModel(
            cond_channels=c.audio["num_mels"],
            hop_length=c.audio["hop_length"],
            uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"],
            cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"],
            cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"],
            window_sizes=c.discriminator_model_params["window_sizes"],
        )
    if c.discriminator_model in "melgan_multiscale_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_sizes=(5, 3),
            base_channels=c.discriminator_model_params["base_channels"],
            max_channels=c.discriminator_model_params["max_channels"],
            downsample_factors=c.discriminator_model_params["downsample_factors"],
        )
    if c.discriminator_model == "residual_parallel_wavegan_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            stacks=c.discriminator_model_params["stacks"],
            res_channels=64,
            gate_channels=128,
            skip_channels=64,
            dropout=0.0,
            bias=True,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    if c.discriminator_model == "parallel_wavegan_discriminator":
        model = MyModel(
            in_channels=1,
            out_channels=1,
            kernel_size=3,
            num_layers=c.discriminator_model_params["num_layers"],
            conv_channels=64,
            dilation_factor=1,
            nonlinear_activation="LeakyReLU",
            nonlinear_activation_params={"negative_slope": 0.2},
        )
    if c.discriminator_model == "univnet_discriminator":
        model = MyModel()
    return model
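A hypothetical end-to-end use of the new factory, with `UnivnetConfig` purely as an example (import path assumed):

```python
from TTS.vocoder.configs.univnet_config import UnivnetConfig  # path assumed
from TTS.vocoder.models import setup_model

config = UnivnetConfig()
# The config defines both generator_model and discriminator_model, so
# setup_model resolves to the GAN wrapper rather than a bare generator.
model = setup_model(config)
```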