Update notebooks

pull/10/head
Eren Golge 2019-06-03 11:25:43 +02:00
parent f29d1840f9
commit 11d5e5bae5
2 changed files with 67 additions and 30 deletions

View File

@@ -105,21 +105,21 @@
 "outputs": [],
 "source": [
 "# Set constants\n",
-"ROOT_PATH = '/media/erogol/data_ssd/Data/models/mozilla_models/4842/'\n",
+"ROOT_PATH = '/media/erogol/data_ssd/Data/models/mozilla_models/4845/'\n",
 "MODEL_PATH = ROOT_PATH + '/best_model.pth.tar'\n",
 "CONFIG_PATH = ROOT_PATH + '/config.json'\n",
 "OUT_FOLDER = '/home/erogol/Dropbox/AudioSamples/benchmark_samples/'\n",
 "CONFIG = load_config(CONFIG_PATH)\n",
-"VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-4841-May-26-2019_01+50PM-df8cfe1/model_checkpoints/best_model.pth.tar\"\n",
-"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-4841-May-26-2019_04+23AM-df8cfe1/config.json\"\n",
+"VOCODER_MODEL_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-May24-4763/model_checkpoints/best_model.pth.tar\"\n",
+"VOCODER_CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/wavernn/mozilla/mozilla-May24-4763/config.json\"\n",
 "VOCODER_CONFIG = load_config(VOCODER_CONFIG_PATH)\n",
 "use_cuda = False\n",
 "\n",
 "# Set some config fields manually for testing\n",
-"CONFIG.windowing = False\n",
-"CONFIG.prenet_dropout = False\n",
-"CONFIG.separate_stopnet = True\n",
-"CONFIG.stopnet = True\n",
+"# CONFIG.windowing = False\n",
+"# CONFIG.prenet_dropout = False\n",
+"# CONFIG.separate_stopnet = True\n",
+"# CONFIG.stopnet = True\n",
 "\n",
 "# Set the vocoder\n",
 "use_gl = True # use Griffin-Lim if True\n",

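The cell above relies on load_config returning an attribute-style object, so inference-time settings can be toggled after loading (this commit simply comments the overrides out rather than deleting them). A minimal sketch of the pattern, with a placeholder path instead of the commit's local ones:

    # load_config is imported in the notebook from TTS.utils.generic_utils
    from TTS.utils.generic_utils import load_config

    CONFIG = load_config('/path/to/config.json')  # placeholder path
    CONFIG.windowing = False                      # toggle attention windowing
    CONFIG.prenet_dropout = False                 # disable prenet dropout for eval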
View File

@@ -22,6 +22,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"%load_ext autoreload\n",
+"%autoreload 2\n",
 "import os\n",
 "import sys\n",
 "sys.path.append(TTS_PATH)\n",
@@ -34,12 +36,12 @@
 "from TTS.datasets.TTSDataset import MyDataset\n",
 "from TTS.utils.audio import AudioProcessor\n",
 "from TTS.utils.visual import plot_spectrogram\n",
-"from TTS.utils.generic_utils import load_config\n",
+"from TTS.utils.generic_utils import load_config, setup_model\n",
 "from TTS.datasets.preprocess import ljspeech\n",
 "%matplotlib inline\n",
 "\n",
 "import os\n",
-"os.environ['CUDA_VISIBLE_DEVICES']='0'"
+"os.environ['CUDA_VISIBLE_DEVICES']='1'"
 ]
 },
 {
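Two small quality-of-life changes land in this import cell. The autoreload magics re-import edited modules before every cell execution (mode 2 covers all non-excluded modules), so changes to the TTS source take effect without restarting the kernel. And the CUDA_VISIBLE_DEVICES switch from '0' to '1' pins the notebook to the second GPU; note the variable is only honored when set before CUDA initializes. A minimal sketch of both, outside the notebook's JSON:

    %load_ext autoreload
    %autoreload 2   # reload all (non-excluded) modules before each cell runs

    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # must be set before any CUDA call
    import torch
    print(torch.cuda.device_count())          # now reports one visible device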
@@ -66,18 +68,22 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"OUT_PATH = \"/home/erogol/Data/LJSpeech-1.1/wavernn_4152/\"\n",
-"DATA_PATH = \"/home/erogol/Data/LJSpeech-1.1/\"\n",
-"METADATA_FILE = \"metadata_train.csv\"\n",
-"CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/ljspeech_models/4258_nancy/config.json\"\n",
-"MODEL_FILE = \"/home/erogol/checkpoint_92000.pth.tar\"\n",
-"DRY_RUN = True # if True, does not generate output files, only computes loss and visuals.\n",
-"BATCH_SIZE = 16\n",
+"OUT_PATH = \"/home/erogol/Data/Mozilla/wavernn/4841/\"\n",
+"DATA_PATH = \"/home/erogol/Data/Mozilla/\"\n",
+"DATASET = \"mozilla\"\n",
+"METADATA_FILE = \"metadata.txt\"\n",
+"CONFIG_PATH = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/config.json\"\n",
+"MODEL_FILE = \"/media/erogol/data_ssd/Data/models/mozilla_models/4841/best_model.pth.tar\"\n",
+"DRY_RUN = False # if True, does not generate output files, only computes loss and visuals.\n",
+"BATCH_SIZE = 32\n",
+"\n",
 "use_cuda = torch.cuda.is_available()\n",
 "print(\" > CUDA enabled: \", use_cuda)\n",
 "\n",
 "C = load_config(CONFIG_PATH)\n",
-"ap = AudioProcessor(bits=9, **C.audio)"
+"ap = AudioProcessor(bits=9, **C.audio)\n",
+"C.prenet_dropout = False\n",
+"C.separate_stopnet = True"
 ]
 },
 {
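For context on the cell above: C.audio is the audio section of config.json, unpacked into AudioProcessor as keyword arguments, and bits=9 requests 9-bit quantization targets for WaveRNN. A short usage sketch, assuming the helper methods behave as elsewhere in the repo:

    from TTS.utils.audio import AudioProcessor

    C = load_config(CONFIG_PATH)
    ap = AudioProcessor(bits=9, **C.audio)     # C.audio: sample_rate, num_mels, ...
    wav = ap.load_wav('/path/to/sample.wav')   # placeholder input file
    mel = ap.melspectrogram(wav)               # mel spectrogram for the model/vocoder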
@@ -86,7 +92,10 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap, ljspeech, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path)\n",
+"preprocessor = importlib.import_module('datasets.preprocess')\n",
+"preprocessor = getattr(preprocessor, DATASET.lower())\n",
+"\n",
+"dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap, preprocessor, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path)\n",
 "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)"
 ]
 },
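Swapping the hard-coded ljspeech preprocessor for an importlib lookup makes this cell dataset-agnostic: any function named after DATASET in datasets.preprocess is resolved at runtime. The pattern in isolation (the preprocessor's role is inferred from how MyDataset consumes it):

    import importlib

    DATASET = 'mozilla'
    module = importlib.import_module('datasets.preprocess')
    preprocessor = getattr(module, DATASET.lower())  # resolves the mozilla() function
    # MyDataset uses the preprocessor to enumerate (text, wav_path) items
    dataset = MyDataset(DATA_PATH, METADATA_FILE, C.r, C.text_cleaner, ap,
                        preprocessor, use_phonemes=C.use_phonemes,
                        phoneme_cache_path=C.phoneme_cache_path)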
@@ -99,11 +108,11 @@
 "from utils.text.symbols import symbols, phonemes\n",
 "from utils.generic_utils import sequence_mask\n",
 "from layers.losses import L1LossMasked\n",
+"from utils.text.symbols import symbols, phonemes\n",
 "\n",
 "# load the model\n",
-"MyModel = importlib.import_module('TTS.models.'+C.model.lower())\n",
-"MyModel = getattr(MyModel, C.model)\n",
 "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n",
-"model = MyModel(num_chars, C.r, attn_win=False)\n",
+"model = setup_model(num_chars, C)\n",
 "checkpoint = torch.load(MODEL_FILE)\n",
 "model.load_state_dict(checkpoint['model'])\n",
+"print(checkpoint['step'])\n",
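setup_model replaces the manual importlib construction: it reads C.model and instantiates the matching architecture with config-driven arguments, which is why the attn_win argument disappears here. Restoring weights then follows the standard PyTorch pattern; the map_location argument below is an addition for CPU-only machines, not part of the commit:

    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    model = setup_model(num_chars, C)
    checkpoint = torch.load(MODEL_FILE, map_location='cpu')  # CPU fallback (assumption)
    model.load_state_dict(checkpoint['model'])  # weights live under the 'model' key
    model.eval()                                # disable dropout for evaluation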
@@ -151,10 +160,19 @@
 "        stop_targets = stop_targets.cuda()\n",
 "\n",
 "    mask = sequence_mask(text_lengths)\n",
-"    mel_outputs, mel_postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input, mask)\n",
+"    mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n",
 "\n",
+"    # compute mel specs from linear spec if model is Tacotron\n",
+"    mel_specs = []\n",
+"    if C.model == \"Tacotron\":\n",
+"        postnet_outputs = postnet_outputs.data.cpu().numpy()\n",
+"        for b in range(postnet_outputs.shape[0]):\n",
+"            postnet_output = postnet_outputs[b]\n",
+"            mel_specs.append(torch.FloatTensor(ap.out_linear_to_mel(postnet_output.T).T).cuda())\n",
+"        postnet_outputs = torch.stack(mel_specs)\n",
+"\n",
 "    loss = criterion(mel_outputs, mel_input, mel_lengths)\n",
-"    loss_postnet = criterion(mel_postnet_outputs, mel_input, mel_lengths)\n",
+"    loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n",
 "    losses.append(loss.item())\n",
 "    postnet_losses.append(loss_postnet.item())\n",
 "    if not DRY_RUN:\n",
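Two details in this loop are easy to miss. When the model is Tacotron, the postnet emits linear spectrograms, so each item is projected back to mel with ap.out_linear_to_mel before being compared against the mel targets (Tacotron2 predicts mels directly and skips the branch). And L1LossMasked uses the true frame counts, so padded positions do not dilute the average. The masked-loss call in isolation:

    criterion = L1LossMasked()
    # predictions and targets are (batch, frames, num_mels); mel_lengths holds
    # each item's valid frame count, used to mask out the padding
    loss = criterion(mel_outputs, mel_input, mel_lengths)
    loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)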
@@ -164,12 +182,12 @@
 "            file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n",
 "            file_idxs.append(file_name)\n",
 "\n",
-"            # quantize and save wav\n",
-"            wavq = ap.quantize(wav)\n",
-"            np.save(wavq_path, wavq)\n",
+"#             # quantize and save wav\n",
+"#             wavq = ap.quantize(wav)\n",
+"#             np.save(wavq_path, wavq)\n",
 "\n",
 "            # save TTS mel\n",
-"            mel = mel_postnet_outputs[idx]\n",
+"            mel = postnet_outputs[idx]\n",
 "            mel = mel.data.cpu().numpy()\n",
 "            mel_length = mel_lengths[idx]\n",
 "            mel = mel[:mel_length, :].T\n",
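The save branch trims each predicted mel to its true length and transposes it to (num_mels, frames), the layout the downstream WaveRNN loader expects here; the 9-bit wav quantization is commented out because this run only exports mels. The per-item save, in isolation:

    import numpy as np

    mel = postnet_outputs[idx].data.cpu().numpy()
    mel = mel[:mel_lengths[idx], :].T  # drop padded frames, then (num_mels, frames)
    np.save(mel_path, mel)             # mel_path comes from set_filename(...)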
@@ -202,7 +220,18 @@
 "outputs": [],
 "source": [
 "idx = 1\n",
-"mel_example = mel_postnet_outputs[idx].data.cpu().numpy()\n",
+"mel_example = postnet_outputs[idx].data.cpu().numpy()\n",
 "plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
 "print(mel_example[:mel_lengths[1], :].shape)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"mel_example = mel_outputs[idx].data.cpu().numpy()\n",
+"plot_spectrogram(mel_example[:mel_lengths[idx], :], ap);\n",
+"print(mel_example[:mel_lengths[1], :].shape)"
 ]
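The new cell mirrors the previous one but plots the decoder output (mel_outputs) instead of the postnet output, making the postnet's refinement visible side by side. One small robustness tweak, not in the commit: index mel_lengths with idx rather than the hard-coded 1, so the print stays correct when idx changes:

    mel_example = mel_outputs[idx].data.cpu().numpy()
    plot_spectrogram(mel_example[:mel_lengths[idx], :], ap)
    print(mel_example[:mel_lengths[idx], :].shape)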
@@ -225,8 +254,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# postnet, decoder diff\n",
 "from matplotlib import pylab as plt\n",
-"mel_diff = mel_outputs[idx] - mel_postnet_outputs[idx]\n",
+"mel_diff = mel_outputs[idx] - postnet_outputs[idx]\n",
 "plt.figure(figsize=(16, 10))\n",
 "plt.imshow(abs(mel_diff.detach().cpu().numpy()[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n",
 "plt.colorbar()\n",
@@ -241,13 +271,20 @@
 "source": [
 "from matplotlib import pylab as plt\n",
 "# mel = mel_outputs[idx].detach().cpu().numpy()\n",
-"mel = mel_postnet_outputs[idx].detach().cpu().numpy()\n",
+"mel = postnet_outputs[idx].detach().cpu().numpy()\n",
 "mel_diff2 = melt.T - mel[:melt.shape[1]]\n",
 "plt.figure(figsize=(16, 10))\n",
 "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n",
 "plt.colorbar()\n",
 "plt.tight_layout()"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": []
+}
 ],
 "metadata": {