"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
|
||||
Sample run on LJSpeech dataset.
|
||||
|
||||
>>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
|
||||
--model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
|
||||
--config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
|
||||
--dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
|
||||
--data_path /home/erogol/Data/LJSpeech-1.1/ \
|
||||
--batch_size 16 \
|
||||
--use_cuda true
|
||||
|
||||
"""

import argparse
import importlib
import os

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.io import load_checkpoint
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Extract attention masks from trained Tacotron models.')
    parser.add_argument('--model_path',
                        type=str,
                        required=True,
                        help='Path to Tacotron or Tacotron2 model file.')
    parser.add_argument('--config_path',
                        type=str,
                        required=True,
                        help='Path to config file for training.')
    parser.add_argument('--dataset',
                        type=str,
                        default='',
                        help='Dataset preprocessor from TTS.tts.datasets.preprocess.')
    parser.add_argument('--dataset_metafile',
                        type=str,
                        default='',
                        help='Dataset metafile including file paths with transcripts.')
    parser.add_argument('--data_path',
                        type=str,
                        default='',
                        help='Defines the data path. It overwrites config.json.')
    parser.add_argument('--output_path',
                        type=str,
                        default='',
                        help='Path for training outputs.')
    parser.add_argument('--output_folder',
                        type=str,
                        default='',
                        help='Folder name for training outputs.')
    # NOTE: `type=bool` would turn any non-empty string (even "false") into
    # True, so parse the flag explicitly.
    parser.add_argument('--use_cuda',
                        type=lambda x: str(x).lower() in ('true', '1'),
                        default=False,
                        help='Enable/disable CUDA.')
    parser.add_argument('--batch_size',
                        type=int,
                        default=16,
                        help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
    args = parser.parse_args()

    C = load_config(args.config_path)
    ap = AudioProcessor(**C.audio)

    # if the vocabulary was passed, replace the default
    if 'characters' in C.keys():
        symbols, phonemes = make_symbols(**C.characters)

    # load the model
    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
    # TODO: handle multi-speaker
    model = setup_model(num_chars, num_speakers=0, c=C)
    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
    model.eval()
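
    # `model.decoder.r` (the decoder reduction factor) is reused below: the
    # dataset needs it to frame the mel targets, and each alignment is later
    # upsampled by the same factor back to frame resolution.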
    # data loader
    preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset)
    meta_data = preprocessor(args.data_path, args.dataset_metafile)
    dataset = MyDataset(model.decoder.r,
                        C.text_cleaner,
                        compute_linear_spec=False,
                        ap=ap,
                        meta_data=meta_data,
                        tp=C.characters if 'characters' in C.keys() else None,
                        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
                        use_phonemes=C.use_phonemes,
                        phoneme_cache_path=C.phoneme_cache_path,
                        phoneme_language=C.phoneme_language,
                        enable_eos_bos=C.enable_eos_bos_chars)

    dataset.sort_items()
    loader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=4,
                        collate_fn=dataset.collate_fn,
                        shuffle=False,
                        drop_last=False)
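
    # `shuffle=False` and `drop_last=False` keep a one-to-one, ordered mapping
    # between dataset items and the alignment files written below.
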
    # compute attentions
    file_paths = []
    with torch.no_grad():
        for data in tqdm(loader):
            # setup input data (data[2] is skipped)
            text_input = data[0]
            text_lengths = data[1]
            linear_input = data[3]
            mel_input = data[4]
            mel_lengths = data[5]
            stop_targets = data[6]
            item_idxs = data[7]

            # dispatch data to GPU
            if args.use_cuda:
                text_input = text_input.cuda()
                text_lengths = text_lengths.cuda()
                mel_input = mel_input.cuda()
                mel_lengths = mel_lengths.cuda()

            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
                text_input, text_lengths, mel_input)

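            # Tacotron decoders predict `r` mel frames per decoder step, so the
            # raw alignment has mel_length / r rows; the nearest-neighbour
            # interpolation below stretches it back to one row per mel frame.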
            alignments = alignments.detach()
            for idx, alignment in enumerate(alignments):
                item_idx = item_idxs[idx]
                # interpolate if r > 1
                alignment = torch.nn.functional.interpolate(
                    alignment.transpose(0, 1).unsqueeze(0),
                    size=None,
                    scale_factor=model.decoder.r,
                    mode='nearest',
                    align_corners=None,
                    recompute_scale_factor=None).squeeze(0).transpose(0, 1)

                # remove paddings
                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()

                # set file paths
                wav_file_name = os.path.basename(item_idx)
                align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
                file_path = item_idx.replace(wav_file_name, align_file_name)

                # save output
                file_paths.append([item_idx, file_path])
                np.save(file_path, alignment)

    # output metafile
    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")

    with open(metafile, "w") as f:
        for p in file_paths:
            f.write(f"{p[0]}|{p[1]}\n")
    print(f" >> Metafile created: {metafile}")
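
# A minimal follow-up sketch (not part of the original script) for loading and
# checking one of the saved alignments; the metafile path is an assumption
# matching the `--data_path` used in the sample run above:
#
#   import numpy as np
#   line = open("/home/erogol/Data/LJSpeech-1.1/metadata_attn_mask.txt").readline()
#   wav_path, attn_path = line.strip().split("|")
#   alignment = np.load(attn_path)
#   print(alignment.shape)  # (mel_frames, text_length), paddings removed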