From 2abe3df1536ba956bc03a471716be3803f66e7c2 Mon Sep 17 00:00:00 2001
From: erogol
Date: Mon, 28 Dec 2020 14:44:01 +0100
Subject: [PATCH] compute_attention_masks.py

---
 TTS/bin/compute_attention_masks.py | 169 +++++++++++++++++++++++++++++
 1 file changed, 169 insertions(+)
 create mode 100644 TTS/bin/compute_attention_masks.py

diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py
new file mode 100644
index 00000000..554e4269
--- /dev/null
+++ b/TTS/bin/compute_attention_masks.py
@@ -0,0 +1,169 @@
+"""Compute attention masks from pre-trained Tacotron or Tacotron2 models.
+Sample run on the LJSpeech dataset.
+
+    >>>> CUDA_VISIBLE_DEVICES="0" python TTS/bin/compute_attention_masks.py \
+        --model_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_100000.pth.tar \
+        --config_path /home/erogol/Cluster/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/config.json --dataset ljspeech \
+        --dataset_metafile /home/erogol/Data/LJSpeech-1.1/metadata.csv \
+        --data_path /home/erogol/Data/LJSpeech-1.1/ \
+        --batch_size 16 \
+        --use_cuda true
+
+"""
+
+
+import argparse
+import importlib
+import os
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from TTS.tts.datasets.TTSDataset import MyDataset
+from TTS.tts.utils.generic_utils import setup_model
+from TTS.tts.utils.io import load_checkpoint
+from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
+from TTS.utils.audio import AudioProcessor
+from TTS.utils.io import load_config
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Extract attention masks from trained Tacotron models.')
+    parser.add_argument('--model_path',
+                        type=str,
+                        help='Path to Tacotron or Tacotron2 model file.')
+    parser.add_argument(
+        '--config_path',
+        type=str,
+        required=True,
+        help='Path to config file for training.',
+    )
+    parser.add_argument('--dataset',
+                        type=str,
+                        default='',
+                        help='Dataset from TTS.tts.datasets.preprocess.')
+
+    parser.add_argument(
+        '--dataset_metafile',
+        type=str,
+        default='',
+        help='Dataset metafile including file paths with transcripts.')
+    parser.add_argument(
+        '--data_path',
+        type=str,
+        default='',
+        help='Defines the data path. It overrides the one in config.json.')
+    parser.add_argument('--output_path',
+                        type=str,
+                        help='Path for training outputs.',
+                        default='')
+    parser.add_argument('--output_folder',
+                        type=str,
+                        default='',
+                        help='Folder name for training outputs.')
+
+    parser.add_argument('--use_cuda',
+                        type=lambda x: x.lower() in ('true', '1', 'yes'),
+                        default=False,
+                        help='Enable/disable CUDA.')
+
+    parser.add_argument(
+        '--batch_size',
+        default=16,
+        type=int,
+        help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
+    args = parser.parse_args()
+
+    C = load_config(args.config_path)
+    ap = AudioProcessor(**C.audio)
+
+    # if the vocabulary was passed, replace the default
+    if 'characters' in C.keys():
+        symbols, phonemes = make_symbols(**C.characters)
+
+    # load the model
+    num_chars = len(phonemes) if C.use_phonemes else len(symbols)
+    # TODO: handle multi-speaker
+    model = setup_model(num_chars, num_speakers=0, c=C)
+    model, _ = load_checkpoint(model, args.model_path, None, args.use_cuda)
+    model.eval()
+
+    # data loader
+    preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
+    preprocessor = getattr(preprocessor, args.dataset)
+    meta_data = preprocessor(args.data_path, args.dataset_metafile)
+    dataset = MyDataset(model.decoder.r,
+                        C.text_cleaner,
+                        compute_linear_spec=False,
+                        ap=ap,
+                        meta_data=meta_data,
+                        tp=C.characters if 'characters' in C.keys() else None,
+                        add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
+                        use_phonemes=C.use_phonemes,
+                        phoneme_cache_path=C.phoneme_cache_path,
+                        phoneme_language=C.phoneme_language,
+                        enable_eos_bos=C.enable_eos_bos_chars)
+
+    dataset.sort_items()
+    loader = DataLoader(dataset,
+                        batch_size=args.batch_size,
+                        num_workers=4,
+                        collate_fn=dataset.collate_fn,
+                        shuffle=False,
+                        drop_last=False)
+
+    # compute attentions
+    file_paths = []
+    with torch.no_grad():
+        for data in tqdm(loader):
+            # setup input data
+            text_input = data[0]
+            text_lengths = data[1]
+            linear_input = data[3]
+            mel_input = data[4]
+            mel_lengths = data[5]
+            stop_targets = data[6]
+            item_idxs = data[7]
+
+            # dispatch data to GPU
+            if args.use_cuda:
+                text_input = text_input.cuda()
+                text_lengths = text_lengths.cuda()
+                mel_input = mel_input.cuda()
+                mel_lengths = mel_lengths.cuda()
+
+            mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
+                text_input, text_lengths, mel_input)
+
+            alignments = alignments.detach()
+            for idx, alignment in enumerate(alignments):
+                item_idx = item_idxs[idx]
+                # interpolate if r > 1
+                alignment = torch.nn.functional.interpolate(
+                    alignment.transpose(0, 1).unsqueeze(0),
+                    size=None,
+                    scale_factor=model.decoder.r,
+                    mode='nearest',
+                    align_corners=None,
+                    recompute_scale_factor=None).squeeze(0).transpose(0, 1)
+
+                # remove paddings
+                alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
+
+                # set file paths
+                wav_file_name = os.path.basename(item_idx)
+                align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
+                file_path = item_idx.replace(wav_file_name, align_file_name)
+                # save output
+                file_paths.append([item_idx, file_path])
+                np.save(file_path, alignment)
+
+    # output metafile
+    metafile = os.path.join(args.data_path, "metadata_attn_mask.txt")
+
+    with open(metafile, "w") as f:
+        for p in file_paths:
+            f.write(f"{p[0]}|{p[1]}\n")
+    print(f" >> Metafile created: {metafile}")
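
For reference, a minimal sketch of how the artifacts written by this script could be consumed downstream. It assumes only what the patch above guarantees: metadata_attn_mask.txt holds one wav_path|alignment_npy_path pair per line, and each .npy file stores a [mel_frames, text_length] alignment matrix with padding removed. The metafile path follows the sample run in the docstring; it is illustrative, not fixed.

    import numpy as np

    # Parse the metafile written at the end of compute_attention_masks.py
    # (one "wav_path|alignment_npy_path" pair per line).
    with open("/home/erogol/Data/LJSpeech-1.1/metadata_attn_mask.txt") as f:
        pairs = [line.strip().split("|") for line in f if line.strip()]

    # Each alignment has shape [mel_frames, text_length]: rows index decoder
    # output frames, columns index input characters/phonemes.
    wav_path, attn_path = pairs[0]
    alignment = np.load(attn_path)
    print(wav_path, alignment.shape)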