{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "This is a notebook to generate mel-spectrograms from a TTS model to be used for WaveRNN training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "import os\n", "import sys\n", "import torch\n", "import importlib\n", "import numpy as np\n", "from tqdm import tqdm as tqdm\n", "from torch.utils.data import DataLoader\n", "from mozilla_voice_tts.tts.datasets.TTSDataset import MyDataset\n", "from mozilla_voice_tts.tts.layers.losses import L1LossMasked\n", "from mozilla_voice_tts.tts.utils.audio import AudioProcessor\n", "from mozilla_voice_tts.tts.utils.visual import plot_spectrogram\n", "from mozilla_voice_tts.tts.utils.generic_utils import load_config, setup_model, sequence_mask\n", "from mozilla_voice_tts.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "%matplotlib inline\n", "\n", "import os\n", "os.environ['CUDA_VISIBLE_DEVICES']='0'" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def set_filename(wav_path, out_path):\n", " wav_file = os.path.basename(wav_path)\n", " file_name = wav_file.split('.')[0]\n", " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n", " wavq_path = os.path.join(out_path, \"quant\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n", " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", " return file_name, wavq_path, mel_path, wav_path" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "OUT_PATH = \"/home/erogol/Data/LJSpeech-1.1/ljspeech-March-17-2020_01+16AM-871588c/\"\n", "DATA_PATH = \"/home/erogol/Data/LJSpeech-1.1/\"\n", "DATASET = \"ljspeech\"\n", "METADATA_FILE = \"metadata.csv\"\n", "CONFIG_PATH = \"/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/config.json\"\n", "MODEL_FILE = \"/home/erogol/Models/LJSpeech/ljspeech-March-17-2020_01+16AM-871588c/checkpoint_420000.pth.tar\"\n", "BATCH_SIZE = 32\n", "\n", "QUANTIZED_WAV = False\n", "QUANTIZE_BIT = 9\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "\n", "use_cuda = torch.cuda.is_available()\n", "print(\" > CUDA enabled: \", use_cuda)\n", "\n", "C = load_config(CONFIG_PATH)\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! disable to align mel specs with the wav files\n", "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# if the vocabulary was passed, replace the default\n", "if 'characters' in C.keys():\n", " symbols, phonemes = make_symbols(**C.characters)\n", "\n", "# load the model\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speaker\n", "model = setup_model(num_chars, num_speakers=0, c=C)\n", "checkpoint = torch.load(MODEL_FILE)\n", "model.load_state_dict(checkpoint['model'])\n", "print(checkpoint['step'])\n", "model.eval()\n", "model.decoder.set_r(checkpoint['r'])\n", "if use_cuda:\n", " model = model.cuda()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "preprocessor = importlib.import_module('mozilla_voice_tts.tts.datasets.preprocess')\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", "dataset = MyDataset(checkpoint['r'], C.text_cleaner, False, ap, meta_data,tp=C.characters if 'characters' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Generate model outputs " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "file_idxs = []\n", "metadata = []\n", "losses = []\n", "postnet_losses = []\n", "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", "with torch.no_grad():\n", " for data in tqdm(loader):\n", " # setup input data\n", " text_input = data[0]\n", " text_lengths = data[1]\n", " linear_input = data[3]\n", " mel_input = data[4]\n", " mel_lengths = data[5]\n", " stop_targets = data[6]\n", " item_idx = data[7]\n", "\n", " # dispatch data to GPU\n", " if use_cuda:\n", " text_input = text_input.cuda()\n", " text_lengths = text_lengths.cuda()\n", " mel_input = mel_input.cuda()\n", " mel_lengths = mel_lengths.cuda()\n", "\n", " mask = sequence_mask(text_lengths)\n", " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", " \n", " # compute loss\n", " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", " losses.append(loss.item())\n", " postnet_losses.append(loss_postnet.item())\n", "\n", " # compute mel specs from linear spec if model is Tacotron\n", " if C.model == \"Tacotron\":\n", " mel_specs = []\n", " postnet_outputs = postnet_outputs.data.cpu().numpy()\n", " for b in range(postnet_outputs.shape[0]):\n", " postnet_output = postnet_outputs[b]\n", " mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n", " postnet_outputs = torch.stack(mel_specs)\n", " elif C.model == \"Tacotron2\":\n", " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", " alignments = alignments.detach().cpu().numpy()\n", "\n", " if not DRY_RUN:\n", " for idx in range(text_input.shape[0]):\n", " wav_file_path = item_idx[idx]\n", " wav = ap.load_wav(wav_file_path)\n", " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", " file_idxs.append(file_name)\n", "\n", " # quantize and save wav\n", " if QUANTIZED_WAV:\n", " wavq = ap.quantize(wav)\n", " np.save(wavq_path, wavq)\n", "\n", " # save TTS mel\n", " mel = postnet_outputs[idx]\n", " mel_length = mel_lengths[idx]\n", " mel = mel[:mel_length, :].T\n", " np.save(mel_path, mel)\n", "\n", " metadata.append([wav_file_path, mel_path])\n", "\n", " # for wavernn\n", " if not DRY_RUN:\n", " pickle.dump(file_idxs, open(OUT_PATH+\"/dataset_ids.pkl\", \"wb\")) \n", " \n", " # for pwgan\n", " with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", " for data in metadata:\n", " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", "\n", " print(np.mean(losses))\n", " print(np.mean(postnet_losses))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# for pwgan\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", " for data in metadata:\n", " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sanity Check" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "idx = 1\n", "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import soundfile as sf\n", "wav, sr = sf.read(item_idx[idx])\n", "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n", "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", "mel_truth = ap.melspectrogram(wav)\n", "print(mel_truth.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot posnet output\n", "plot_spectrogram(mel_postnet, ap);\n", "print(mel_postnet[:mel_lengths[idx], :].shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot decoder output\n", "plot_spectrogram(mel_decoder, ap);\n", "print(mel_decoder.shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# plot GT specgrogram\n", "print(mel_truth.shape)\n", "plot_spectrogram(mel_truth.T, ap);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# postnet, decoder diff\n", "from matplotlib import pylab as plt\n", "mel_diff = mel_decoder - mel_postnet\n", "plt.figure(figsize=(16, 10))\n", "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", "plt.colorbar()\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", "from matplotlib import pylab as plt\n", "mel_diff2 = mel_truth.T - mel_decoder\n", "plt.figure(figsize=(16, 10))\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.colorbar()\n", "plt.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", "from matplotlib import pylab as plt\n", "mel = postnet_outputs[idx]\n", "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n", "plt.figure(figsize=(16, 10))\n", "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", "plt.colorbar()\n", "plt.tight_layout()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 4 }