Update notebooks

pull/887/head
Eren Gölge 2021-10-21 16:20:14 +00:00
parent 5e0d0539c5
commit 016803beee
5 changed files with 243 additions and 69 deletions

View File

@@ -2,14 +2,16 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a notebook to generate mel-spectrograms from a TTS model to be used in a Vocoder training."
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -20,7 +22,7 @@
"import numpy as np\n",
"from tqdm import tqdm as tqdm\n",
"from torch.utils.data import DataLoader\n",
"from TTS.tts.datasets.TTSDataset import TTSDataset\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.config import load_config\n",
@@ -33,13 +35,13 @@
"\n",
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICES']='2'"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def set_filename(wav_path, out_path):\n",
" wav_file = os.path.basename(wav_path)\n",
@@ -51,13 +53,13 @@
" mel_path = os.path.join(out_path, \"mel\", file_name)\n",
" wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n",
" return file_name, wavq_path, mel_path, wav_path"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
@@ -77,13 +79,13 @@
"C = load_config(CONFIG_PATH)\n",
"C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n",
"ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(C['r'])\n",
"# if the vocabulary was passed, replace the default\n",
@@ -95,13 +97,13 @@
"# TODO: multiple speaker\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n",
"preprocessor = getattr(preprocessor, DATASET.lower())\n",
@@ -120,20 +122,20 @@
"loader = DataLoader(\n",
" dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n",
")\n"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate model outputs "
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pickle\n",
"\n",
@@ -212,42 +214,42 @@
"\n",
" print(np.mean(losses))\n",
" print(np.mean(postnet_losses))"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# for pwgan\n",
"with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n",
" for data in metadata:\n",
" f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Check"
],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"idx = 1\n",
"ap.melspectrogram(ap.load_wav(item_idx[idx])).shape"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import soundfile as sf\n",
"wav, sr = sf.read(item_idx[idx])\n",
@@ -255,46 +257,46 @@
"mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot posnet output\n",
"print(mel_postnet[:mel_lengths[idx], :].shape)\n",
"plot_spectrogram(mel_postnet, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot decoder output\n",
"print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot GT specgrogram\n",
"print(mel_truth.shape)\n",
"plot_spectrogram(mel_truth.T, ap)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# postnet, decoder diff\n",
"from matplotlib import pylab as plt\n",
@@ -303,13 +305,13 @@
"plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@@ -318,13 +320,13 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PLOT GT SPECTROGRAM diff\n",
"from matplotlib import pylab as plt\n",
@@ -334,22 +336,23 @@
"plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
"plt.colorbar()\n",
"plt.tight_layout()"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "code",
"execution_count": null,
"source": [],
"metadata": {},
"outputs": [],
"metadata": {}
"source": []
}
],
"metadata": {
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3.9.7 64-bit ('base': conda)"
"display_name": "Python 3.9.7 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -362,11 +365,8 @@
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
},
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
}
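The notebook above extracts mel-spectrograms from a trained TTS model for vocoder training; besides reordering the notebook cell keys, the diff moves the dataset import from `TTS.tts.datasets.TTSDataset` to `TTS.tts.datasets.dataset`. The extraction cell itself is elided in the hunks, so the following is only a minimal sketch of that step: `set_filename()`, `OUT_PATH`, and the `metadata.txt` format are taken from the diff, while the batch keys and the `model.inference()` call are assumptions that may not match the notebook verbatim.

```python
# Minimal sketch of the extraction loop (the full cell is elided in the diff).
# set_filename(), OUT_PATH and the metadata.txt format come from the hunks above;
# the batch key names and model.inference() call are assumptions.
import os
import numpy as np
import torch

metadata = []
with torch.no_grad():
    for batch in loader:                          # DataLoader over TTSDataset, built above
        tokens = batch["token_ids"]               # assumed key names
        mel_lengths = batch["mel_lengths"]
        item_idx = batch["item_idxs"]             # paths of the source wav files
        outputs = model.inference(tokens)         # assumed API; per-item postnet mel frames
        for i, wav_file_path in enumerate(item_idx):
            file_name, _, mel_path, _ = set_filename(wav_file_path, OUT_PATH)
            mel = outputs[i][: mel_lengths[i]].cpu().numpy()
            os.makedirs(os.path.dirname(mel_path), exist_ok=True)
            np.save(mel_path, mel)                # saved as <OUT_PATH>/mel/<file_name>.npy
            metadata.append([wav_file_path, file_name])

# The "# for pwgan" cell then writes one "<wav_path>|<file_name>.npy" line per item:
with open(os.path.join(OUT_PATH, "metadata.txt"), "w") as f:
    for data in metadata:
        f.write(f"{data[0]}|{data[1] + '.npy'}\n")
```

The `wav_path|file_name.npy` lines presumably match what a ParallelWaveGAN-style vocoder recipe expects, which is why that cell is labelled `# for pwgan`.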

View File

@@ -19,19 +19,16 @@
"source": [
"import os\n",
"import glob\n",
"import random\n",
"import numpy as np\n",
"import torch\n",
"import umap\n",
"\n",
"from TTS.speaker_encoder.model import SpeakerEncoder\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.tts.utils.generic_utils import load_config\n",
"from TTS.config import load_config\n",
"\n",
"from bokeh.io import output_notebook, show\n",
"from bokeh.plotting import figure\n",
"from bokeh.models import HoverTool, ColumnDataSource, BoxZoomTool, ResetTool, OpenURL, TapTool\n",
"from bokeh.transform import factor_cmap, factor_mark\n",
"from bokeh.transform import factor_cmap\n",
"from bokeh.palettes import Category10"
]
},
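This second notebook visualizes speaker-encoder embeddings; the diff swaps the old `TTS.tts.utils.generic_utils.load_config` for `TTS.config.load_config` and drops `factor_mark` along with a few other unused imports. As a rough sketch of what the remaining imports support, and assuming an `embeddings` array (N x D) and a matching `speakers` label list are computed earlier in the notebook with the `SpeakerEncoder`, the UMAP projection and Bokeh scatter would look roughly like this:

```python
# Rough sketch: project speaker embeddings to 2-D with UMAP and draw an
# interactive Bokeh scatter, one color per speaker. `embeddings` and `speakers`
# (string labels) are assumed to exist; only the plotting part is shown.
import umap
from bokeh.io import output_notebook, show
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Category10
from bokeh.plotting import figure
from bokeh.transform import factor_cmap

output_notebook()

projection = umap.UMAP(n_neighbors=15, min_dist=0.1).fit_transform(embeddings)
source = ColumnDataSource(dict(x=projection[:, 0], y=projection[:, 1], speaker=speakers))
unique_speakers = sorted(set(speakers))

p = figure(title="Speaker embeddings (UMAP)",
           tools=[HoverTool(tooltips=[("speaker", "@speaker")]), "pan", "wheel_zoom", "reset"])
p.circle("x", "y", source=source, size=6,
         color=factor_cmap("speaker", palette=Category10[10], factors=unique_speakers))
show(p)
```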

View File

@@ -22,7 +22,6 @@
"import os\n",
"import sys\n",
"sys.path.append(TTS_PATH) # set this if TTS is not installed globally\n",
"import glob\n",
"import librosa\n",
"import numpy as np\n",
"import pandas as pd\n",

View File

@@ -21,10 +21,9 @@
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"import os\n",
"import glob\n",
"import subprocess\n",
"import tempfile\n",
"import IPython\n",
"import soundfile as sf\n",
"import numpy as np\n",
@@ -208,4 +207,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}
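The last visible hunk edits a notebook that inspects dataset audio with `soundfile` and `IPython`; the diff drops `sys` and one more unused import. A small hedged example of how those two imports are typically combined in such a notebook (file locations are assumptions):

```python
# Hedged sketch: read a wav with soundfile and play it inline in the notebook.
# The glob pattern is an assumption; only the soundfile/IPython usage is illustrated.
import glob
import soundfile as sf
import IPython.display as ipd

wav_files = sorted(glob.glob("/path/to/dataset/wavs/*.wav"))  # assumed location
wav, sr = sf.read(wav_files[0])
ipd.display(ipd.Audio(wav, rate=sr))  # renders an inline audio player
```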

File diff suppressed because one or more lines are too long