compute_attention_masks.py

pull/10/head
erogol 2020-12-28 14:44:01 +01:00
parent cf869e8922
commit 2abe3df153
1 changed files with 170 additions and 0 deletions

View File

@ -0,0 +1,170 @@
"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
Sample run on LJSpeech dataset.
>>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
--model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
--config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
--dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
--data_path /home/erogol/Data/LJSpeech-1.1/ \
--batch_size 16 \
--use_cuda true
"""
import argparse
import glob
import importlib
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import sequence_mask, setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
def str2bool(value):
    """Convert a command-line flag value to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"false"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: Raw argument text, or an actual ``bool`` (returned unchanged).

    Returns:
        ``True`` for true/t/yes/y/1/on, ``False`` for false/f/no/n/0/off or an
        empty string (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(
        f"Boolean value expected, got '{value}'.")


def alignment_file_path(wav_path):
    """Return the ``.npy`` path that sits next to ``wav_path``.

    Only the extension of the final path component is swapped. A plain
    ``str.replace`` on the whole path would also rewrite any directory whose
    name happens to contain the file name.

    Args:
        wav_path: Path of the audio file the alignment belongs to.

    Returns:
        ``wav_path`` with its extension (if any) replaced by ``.npy``.
    """
    return os.path.splitext(wav_path)[0] + '.npy'


def main():
    """Compute and store attention alignments for every item of a dataset.

    Loads the model given on the command line, runs it over the dataset in
    batches, saves one alignment matrix per item as ``.npy`` next to the item's
    audio file and finally writes ``metadata_attn_mask.txt`` under
    ``--data_path`` with one ``<item path>|<alignment path>`` line per item.
    """
    parser = argparse.ArgumentParser(
        description='Extract attention masks from trained Tacotron models.')
    # The checkpoint is loaded unconditionally below, so the path is mandatory
    # just like the config path.
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='Path to Tacotron or Tacotron2 model file ')
    parser.add_argument(
        '--config_path',
        type=str,
        required=True,
        help='Path to config file for training.',
    )
    parser.add_argument('--dataset',
                        type=str,
                        default='',
                        help='Dataset from TTS.tts.dataset.preprocess.')
    parser.add_argument(
        '--dataset_metafile',
        type=str,
        default='',
        help='Dataset metafile inclusing file paths with transcripts.')
    parser.add_argument(
        '--data_path',
        type=str,
        default='',
        help='Defines the data path. It overwrites config.json.')
    parser.add_argument('--output_path',
                        type=str,
                        help='path for training outputs.',
                        default='')
    parser.add_argument('--output_folder',
                        type=str,
                        default='',
                        help='folder name for training outputs.')
    # str2bool instead of bool: `--use_cuda false` must really disable CUDA.
    parser.add_argument('--use_cuda',
                        type=str2bool,
                        default=False,
                        help="enable/disable cuda.")
    parser.add_argument(
        '--batch_size',
        default=16,
        type=int,
        help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # If the config defines a custom vocabulary, use it instead of the
    # defaults imported at the top of the file. Separate local names are used
    # so the imported `symbols` / `phonemes` are not shadowed in this function.
    if 'characters' in C.keys():
        char_symbols, char_phonemes = make_symbols(**C.characters)
    else:
        char_symbols, char_phonemes = symbols, phonemes

    # load the model
    num_chars = len(char_phonemes) if C.use_phonemes else len(char_symbols)
    # TODO: handle multi-speaker
    model = setup_model(num_chars, num_speakers=0, c=C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()
    reduction_factor = model.decoder.r

    # data loader
    preprocess_module = importlib.import_module('TTS.tts.datasets.preprocess')
    preprocessor = getattr(preprocess_module, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = MyDataset(
        reduction_factor,
        C.text_cleaner,
        compute_linear_spec=False,
        ap=ap,
        meta_data=meta_data,
        tp=C.characters if 'characters' in C.keys() else None,
        # Read the flag from the loaded config `C`; the lowercase name `c`
        # is not defined in this script.
        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
        use_phonemes=C.use_phonemes,
        phoneme_cache_path=C.phoneme_cache_path,
        phoneme_language=C.phoneme_language,
        enable_eos_bos=C.enable_eos_bos_chars)
    dataset.sort_items()
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=4,
                        collate_fn=dataset.collate_fn,
                        shuffle=False,
                        drop_last=False)

    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (only the fields needed for the forward pass)
            text_input = data[0]
            text_lengths = data[1]
            mel_input = data[4]
            mel_lengths = data[5]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            # Only the attention alignments of the forward pass are used.
            _, _, alignments, _ = model.forward(text_input, text_lengths,
                                                mel_input)
            alignments = alignments.detach()

            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # Stretch the first axis by r so it can be cut with the mel
                # length below; for r == 1 this leaves the values unchanged.
                alignment = torch.nn.functional.interpolate(
                    alignment.transpose(0, 1).unsqueeze(0),
                    scale_factor=reduction_factor,
                    mode='nearest').squeeze(0).transpose(0, 1)
                # remove paddings
                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]]
                alignment = alignment.cpu().numpy()
                # set file paths
                file_path = alignment_file_path(item_idx)
                # save output
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

    # output metafile
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
    with open(metafile, "w") as f:
        for item_path, align_path in file_paths:
            f.write(f"{item_path}|{align_path}\n")
    print(f" >> Metafile created: {metafile}")


if __name__ == '__main__':
    main()