From 0e79fa86ad800af6fa4cd78edbc4b27d7a478d5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Fri, 9 Apr 2021 00:38:08 +0200 Subject: [PATCH] format with black and pylint 2.7.3 --- TTS/bin/compute_attention_masks.py | 139 +++-- TTS/bin/compute_embeddings.py | 81 ++- TTS/bin/compute_statistics.py | 51 +- TTS/bin/convert_melgan_tflite.py | 12 +- TTS/bin/convert_melgan_torch_to_tf.py | 43 +- TTS/bin/convert_tacotron2_tflite.py | 12 +- TTS/bin/convert_tacotron2_torch_to_tf.py | 129 ++--- TTS/bin/distribute.py | 41 +- TTS/bin/find_unique_chars.py | 32 +- TTS/bin/resample.py | 66 +-- TTS/bin/synthesize.py | 111 ++-- TTS/bin/train_align_tts.py | 366 ++++++-------- TTS/bin/train_encoder.py | 153 +++--- TTS/bin/train_glow_tts.py | 263 +++++----- TTS/bin/train_speedy_speech.py | 263 +++++----- TTS/bin/train_tacotron.py | 407 +++++++++------ TTS/bin/train_vocoder_gan.py | 266 +++++----- TTS/bin/train_vocoder_wavegrad.py | 170 +++---- TTS/bin/train_vocoder_wavernn.py | 162 +++--- TTS/bin/tune_wavegrad.py | 58 ++- TTS/server/server.py | 80 +-- TTS/speaker_encoder/dataset.py | 39 +- TTS/speaker_encoder/losses.py | 29 +- TTS/speaker_encoder/model.py | 12 +- TTS/speaker_encoder/utils/generic_utils.py | 138 ++--- TTS/speaker_encoder/utils/prepare_voxceleb.py | 84 ++-- TTS/tts/datasets/TTSDataset.py | 221 ++++---- TTS/tts/datasets/preprocess.py | 164 +++--- TTS/tts/layers/align_tts/mdn.py | 5 +- TTS/tts/layers/feed_forward/decoder.py | 77 +-- .../layers/feed_forward/duration_predictor.py | 15 +- TTS/tts/layers/feed_forward/encoder.py | 90 ++-- TTS/tts/layers/generic/gated_conv.py | 12 +- TTS/tts/layers/generic/normalization.py | 29 +- TTS/tts/layers/generic/pos_encoding.py | 19 +- TTS/tts/layers/generic/res_conv_bn.py | 36 +- TTS/tts/layers/generic/time_depth_sep_conv.py | 30 +- TTS/tts/layers/generic/transformer.py | 37 +- TTS/tts/layers/generic/wavenet.py | 94 ++-- TTS/tts/layers/glow_tts/decoder.py | 70 ++- TTS/tts/layers/glow_tts/duration_predictor.py | 11 +- TTS/tts/layers/glow_tts/encoder.py | 94 ++-- TTS/tts/layers/glow_tts/glow.py | 87 ++-- .../glow_tts/monotonic_align/__init__.py | 6 +- TTS/tts/layers/glow_tts/transformer.py | 190 ++++--- TTS/tts/layers/losses.py | 192 ++++--- TTS/tts/layers/tacotron/attentions.py | 217 ++++---- TTS/tts/layers/tacotron/common_layers.py | 56 +-- TTS/tts/layers/tacotron/gst_layers.py | 75 +-- TTS/tts/layers/tacotron/tacotron.py | 240 +++++---- TTS/tts/layers/tacotron/tacotron2.py | 188 ++++--- TTS/tts/models/align_tts.py | 138 ++--- TTS/tts/models/glow_tts.py | 161 +++--- TTS/tts/models/speedy_speech.py | 75 ++- TTS/tts/models/tacotron.py | 201 +++++--- TTS/tts/models/tacotron2.py | 211 +++++--- TTS/tts/models/tacotron_abstract.py | 104 ++-- TTS/tts/tf/layers/tacotron/common_layers.py | 103 ++-- TTS/tts/tf/layers/tacotron/tacotron2.py | 203 ++++---- TTS/tts/tf/models/tacotron2.py | 99 ++-- TTS/tts/tf/utils/convert_torch_to_tf_utils.py | 45 +- TTS/tts/tf/utils/generic_utils.py | 79 ++- TTS/tts/tf/utils/io.py | 22 +- TTS/tts/tf/utils/tflite.py | 15 +- TTS/tts/utils/chinese_mandarin/numbers.py | 54 +- TTS/tts/utils/chinese_mandarin/phonemizer.py | 4 +- .../chinese_mandarin/pinyinToPhonemes.py | 1 - TTS/tts/utils/data.py | 27 +- TTS/tts/utils/generic_utils.py | 475 ++++++++++-------- TTS/tts/utils/io.py | 61 +-- TTS/tts/utils/speakers.py | 42 +- TTS/tts/utils/ssim.py | 25 +- TTS/tts/utils/synthesis.py | 196 ++++---- TTS/tts/utils/text/__init__.py | 101 ++-- TTS/tts/utils/text/abbreviations.py | 129 ++--- TTS/tts/utils/text/cleaners.py | 64 +-- TTS/tts/utils/text/cmudict.py | 119 ++++- TTS/tts/utils/text/number_norm.py | 33 +- TTS/tts/utils/text/symbols.py | 53 +- TTS/tts/utils/text/time.py | 6 +- TTS/tts/utils/visual.py | 94 ++-- TTS/utils/arguments.py | 78 ++- TTS/utils/audio.py | 176 +++---- TTS/utils/console_logger.py | 61 +-- TTS/utils/distribute.py | 18 +- TTS/utils/generic_utils.py | 66 +-- TTS/utils/io.py | 22 +- TTS/utils/manage.py | 25 +- TTS/utils/radam.py | 69 +-- TTS/utils/synthesizer.py | 29 +- TTS/utils/tensorboard_logger.py | 32 +- TTS/utils/training.py | 39 +- TTS/vocoder/datasets/gan_dataset.py | 57 ++- TTS/vocoder/datasets/preprocess.py | 6 +- TTS/vocoder/datasets/wavegrad_dataset.py | 44 +- TTS/vocoder/datasets/wavernn_dataset.py | 51 +- TTS/vocoder/layers/losses.py | 206 ++++---- TTS/vocoder/layers/melgan.py | 35 +- TTS/vocoder/layers/parallel_wavegan.py | 56 +-- TTS/vocoder/layers/pqmf.py | 16 +- TTS/vocoder/layers/upsample.py | 57 +-- TTS/vocoder/layers/wavegrad.py | 53 +- .../models/fullband_melgan_generator.py | 39 +- TTS/vocoder/models/melgan_discriminator.py | 69 +-- TTS/vocoder/models/melgan_generator.py | 79 ++- .../models/melgan_multiscale_discriminator.py | 53 +- .../models/multiband_melgan_generator.py | 40 +- .../models/parallel_wavegan_discriminator.py | 118 ++--- .../models/parallel_wavegan_generator.py | 92 ++-- .../models/random_window_discriminator.py | 105 ++-- TTS/vocoder/models/wavegrad.py | 69 +-- TTS/vocoder/models/wavernn.py | 127 ++--- TTS/vocoder/tf/layers/melgan.py | 32 +- TTS/vocoder/tf/layers/pqmf.py | 25 +- TTS/vocoder/tf/models/melgan_generator.py | 74 +-- .../tf/models/multiband_melgan_generator.py | 42 +- .../tf/utils/convert_torch_to_tf_utils.py | 26 +- TTS/vocoder/tf/utils/generic_utils.py | 25 +- TTS/vocoder/tf/utils/io.py | 14 +- TTS/vocoder/tf/utils/tflite.py | 15 +- TTS/vocoder/utils/distribution.py | 37 +- TTS/vocoder/utils/generic_utils.py | 153 +++--- TTS/vocoder/utils/io.py | 128 +++-- pyproject.toml | 26 + tests/test_audio.py | 2 +- tests/test_loader.py | 2 +- tests/test_vocoder_losses.py | 12 +- 127 files changed, 5511 insertions(+), 5491 deletions(-) diff --git a/TTS/bin/compute_attention_masks.py b/TTS/bin/compute_attention_masks.py index 53246e07..2ac725bf 100644 --- a/TTS/bin/compute_attention_masks.py +++ b/TTS/bin/compute_attention_masks.py @@ -15,16 +15,14 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.io import load_config -if __name__ == '__main__': - # pylint: disable=bad-continuation +if __name__ == "__main__": + # pylint: disable=bad-option-value parser = argparse.ArgumentParser( - description='''Extract attention masks from trained Tacotron/Tacotron2 models. -These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n''' - -'''Each attention mask is written to the same path as the input wav file with ".npy" file extension. -(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n''' - -''' + description="""Extract attention masks from trained Tacotron/Tacotron2 models. +These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n""" + """Each attention mask is written to the same path as the input wav file with ".npy" file extension. +(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n""" + """ Example run: CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py --model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar @@ -34,53 +32,44 @@ Example run: --batch_size 32 --dataset ljspeech --use_cuda True -''', - formatter_class=RawTextHelpFormatter - ) - parser.add_argument('--model_path', - type=str, - required=True, - help='Path to Tacotron/Tacotron2 model file ') - parser.add_argument( - '--config_path', - type=str, - required=True, - help='Path to Tacotron/Tacotron2 config file.', +""", + formatter_class=RawTextHelpFormatter, ) - parser.add_argument('--dataset', - type=str, - default='', - required=True, - help='Target dataset processor name from TTS.tts.dataset.preprocess.') - + parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ") parser.add_argument( - '--dataset_metafile', + "--config_path", type=str, - default='', required=True, - help='Dataset metafile inclusing file paths with transcripts.') + help="Path to Tacotron/Tacotron2 config file.", + ) parser.add_argument( - '--data_path', + "--dataset", type=str, - default='', - help='Defines the data path. It overwrites config.json.') - parser.add_argument('--use_cuda', - type=bool, - default=False, - help="enable/disable cuda.") + default="", + required=True, + help="Target dataset processor name from TTS.tts.dataset.preprocess.", + ) parser.add_argument( - '--batch_size', - default=16, - type=int, - help='Batch size for the model. Use batch_size=1 if you have no CUDA.') + "--dataset_metafile", + type=str, + default="", + required=True, + help="Dataset metafile inclusing file paths with transcripts.", + ) + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.") + + parser.add_argument( + "--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA." + ) args = parser.parse_args() C = load_config(args.config_path) ap = AudioProcessor(**C.audio) # if the vocabulary was passed, replace the default - if 'characters' in C.keys(): + if "characters" in C.keys(): symbols, phonemes = make_symbols(**C.characters) # load the model @@ -91,28 +80,32 @@ Example run: model.eval() # data loader - preprocessor = importlib.import_module('TTS.tts.datasets.preprocess') + preprocessor = importlib.import_module("TTS.tts.datasets.preprocess") preprocessor = getattr(preprocessor, args.dataset) meta_data = preprocessor(args.data_path, args.dataset_metafile) - dataset = MyDataset(model.decoder.r, - C.text_cleaner, - compute_linear_spec=False, - ap=ap, - meta_data=meta_data, - tp=C.characters if 'characters' in C.keys() else None, - add_blank=C['add_blank'] if 'add_blank' in C.keys() else False, - use_phonemes=C.use_phonemes, - phoneme_cache_path=C.phoneme_cache_path, - phoneme_language=C.phoneme_language, - enable_eos_bos=C.enable_eos_bos_chars) + dataset = MyDataset( + model.decoder.r, + C.text_cleaner, + compute_linear_spec=False, + ap=ap, + meta_data=meta_data, + tp=C.characters if "characters" in C.keys() else None, + add_blank=C["add_blank"] if "add_blank" in C.keys() else False, + use_phonemes=C.use_phonemes, + phoneme_cache_path=C.phoneme_cache_path, + phoneme_language=C.phoneme_language, + enable_eos_bos=C.enable_eos_bos_chars, + ) dataset.sort_items() - loader = DataLoader(dataset, - batch_size=args.batch_size, - num_workers=4, - collate_fn=dataset.collate_fn, - shuffle=False, - drop_last=False) + loader = DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=4, + collate_fn=dataset.collate_fn, + shuffle=False, + drop_last=False, + ) # compute attentions file_paths = [] @@ -134,25 +127,29 @@ Example run: mel_input = mel_input.cuda() mel_lengths = mel_lengths.cuda() - mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward( - text_input, text_lengths, mel_input) + mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input) alignments = alignments.detach() for idx, alignment in enumerate(alignments): item_idx = item_idxs[idx] # interpolate if r > 1 - alignment = torch.nn.functional.interpolate( - alignment.transpose(0, 1).unsqueeze(0), - size=None, - scale_factor=model.decoder.r, - mode='nearest', - align_corners=None, - recompute_scale_factor=None).squeeze(0).transpose(0, 1) + alignment = ( + torch.nn.functional.interpolate( + alignment.transpose(0, 1).unsqueeze(0), + size=None, + scale_factor=model.decoder.r, + mode="nearest", + align_corners=None, + recompute_scale_factor=None, + ) + .squeeze(0) + .transpose(0, 1) + ) # remove paddings - alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy() + alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy() # set file paths wav_file_name = os.path.basename(item_idx) - align_file_name = os.path.splitext(wav_file_name)[0] + '.npy' + align_file_name = os.path.splitext(wav_file_name)[0] + ".npy" file_path = item_idx.replace(wav_file_name, align_file_name) # save output file_paths.append([item_idx, file_path]) diff --git a/TTS/bin/compute_embeddings.py b/TTS/bin/compute_embeddings.py index 64edd140..36ecb0f0 100644 --- a/TTS/bin/compute_embeddings.py +++ b/TTS/bin/compute_embeddings.py @@ -13,91 +13,72 @@ from TTS.tts.utils.speakers import save_speaker_mapping from TTS.tts.datasets.preprocess import load_meta_data parser = argparse.ArgumentParser( - description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.') -parser.add_argument( - 'model_path', - type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).') -parser.add_argument( - 'config_path', - type=str, - help='Path to config file for training.', + description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.' ) +parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).") parser.add_argument( - 'data_path', + "config_path", type=str, - help='Data path for wav files - directory or CSV file') + help="Path to config file for training.", +) +parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file") +parser.add_argument("output_path", type=str, help="path for training outputs.") parser.add_argument( - 'output_path', + "--target_dataset", type=str, - help='path for training outputs.') -parser.add_argument( - '--target_dataset', - type=str, - default='', - help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.' -) -parser.add_argument( - '--use_cuda', type=bool, help='flag to set cuda.', default=False -) -parser.add_argument( - '--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|' + default="", + help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.", ) +parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False) +parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|") args = parser.parse_args() c = load_config(args.config_path) -ap = AudioProcessor(**c['audio']) +ap = AudioProcessor(**c["audio"]) data_path = args.data_path split_ext = os.path.splitext(data_path) sep = args.separator -if args.target_dataset != '': +if args.target_dataset != "": # if target dataset is defined dataset_config = [ - { - "name": args.target_dataset, - "path": args.data_path, - "meta_file_train": None, - "meta_file_val": None - }, + {"name": args.target_dataset, "path": args.data_path, "meta_file_train": None, "meta_file_val": None}, ] wav_files, _ = load_meta_data(dataset_config, eval_split=False) - output_files = [wav_file[1].replace(data_path, args.output_path).replace( - '.wav', '.npy') for wav_file in wav_files] + output_files = [wav_file[1].replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files] else: # if target dataset is not defined - if len(split_ext) > 0 and split_ext[1].lower() == '.csv': + if len(split_ext) > 0 and split_ext[1].lower() == ".csv": # Parse CSV - print(f'CSV file: {data_path}') + print(f"CSV file: {data_path}") with open(data_path) as f: - wav_path = os.path.join(os.path.dirname(data_path), 'wavs') + wav_path = os.path.join(os.path.dirname(data_path), "wavs") wav_files = [] - print(f'Separator is: {sep}') + print(f"Separator is: {sep}") for line in f: components = line.split(sep) if len(components) != 2: print("Invalid line") continue - wav_file = os.path.join(wav_path, components[0] + '.wav') - #print(f'wav_file: {wav_file}') + wav_file = os.path.join(wav_path, components[0] + ".wav") + # print(f'wav_file: {wav_file}') if os.path.exists(wav_file): wav_files.append(wav_file) - print(f'Count of wavs imported: {len(wav_files)}') + print(f"Count of wavs imported: {len(wav_files)}") else: # Parse all wav files in data_path - wav_files = glob.glob(data_path + '/**/*.wav', recursive=True) + wav_files = glob.glob(data_path + "/**/*.wav", recursive=True) - output_files = [wav_file.replace(data_path, args.output_path).replace( - '.wav', '.npy') for wav_file in wav_files] + output_files = [wav_file.replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files] for output_file in output_files: os.makedirs(os.path.dirname(output_file), exist_ok=True) # define Encoder model model = SpeakerEncoder(**c.model) -model.load_state_dict(torch.load(args.model_path)['model']) +model.load_state_dict(torch.load(args.model_path)["model"]) model.eval() if args.use_cuda: model.cuda() @@ -117,14 +98,14 @@ for idx, wav_file in enumerate(tqdm(wav_files)): embedd = embedd.detach().cpu().numpy() np.save(output_files[idx], embedd) - if args.target_dataset != '': + if args.target_dataset != "": # create speaker_mapping if target dataset is defined wav_file_name = os.path.basename(wav_file) speaker_mapping[wav_file_name] = {} - speaker_mapping[wav_file_name]['name'] = speaker_name - speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist() + speaker_mapping[wav_file_name]["name"] = speaker_name + speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist() -if args.target_dataset != '': +if args.target_dataset != "": # save speaker_mapping if target dataset is defined - mapping_file_path = os.path.join(args.output_path, 'speakers.json') + mapping_file_path = os.path.join(args.output_path, "speakers.json") save_speaker_mapping(args.output_path, speaker_mapping) diff --git a/TTS/bin/compute_statistics.py b/TTS/bin/compute_statistics.py index a74fe90a..ce224310 100755 --- a/TTS/bin/compute_statistics.py +++ b/TTS/bin/compute_statistics.py @@ -15,25 +15,24 @@ from TTS.utils.audio import AudioProcessor def main(): """Run preprocessing process.""" - parser = argparse.ArgumentParser( - description="Compute mean and variance of spectrogtram features.") - parser.add_argument("--config_path", type=str, required=True, - help="TTS config file path to define audio processin parameters.") - parser.add_argument("--out_path", type=str, required=True, - help="save path (directory and filename).") + parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.") + parser.add_argument( + "--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters." + ) + parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).") args = parser.parse_args() # load config CONFIG = load_config(args.config_path) - CONFIG.audio['signal_norm'] = False # do not apply earlier normalization - CONFIG.audio['stats_path'] = None # discard pre-defined stats + CONFIG.audio["signal_norm"] = False # do not apply earlier normalization + CONFIG.audio["stats_path"] = None # discard pre-defined stats # load audio processor ap = AudioProcessor(**CONFIG.audio) # load the meta data of target dataset - if 'data_path' in CONFIG.keys(): - dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True) + if "data_path" in CONFIG.keys(): + dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True) else: dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data print(f" > There are {len(dataset_items)} files.") @@ -63,27 +62,27 @@ def main(): output_file_path = args.out_path stats = {} - stats['mel_mean'] = mel_mean - stats['mel_std'] = mel_scale - stats['linear_mean'] = linear_mean - stats['linear_std'] = linear_scale + stats["mel_mean"] = mel_mean + stats["mel_std"] = mel_scale + stats["linear_mean"] = linear_mean + stats["linear_std"] = linear_scale - print(f' > Avg mel spec mean: {mel_mean.mean()}') - print(f' > Avg mel spec scale: {mel_scale.mean()}') - print(f' > Avg linear spec mean: {linear_mean.mean()}') - print(f' > Avg lienar spec scale: {linear_scale.mean()}') + print(f" > Avg mel spec mean: {mel_mean.mean()}") + print(f" > Avg mel spec scale: {mel_scale.mean()}") + print(f" > Avg linear spec mean: {linear_mean.mean()}") + print(f" > Avg lienar spec scale: {linear_scale.mean()}") # set default config values for mean-var scaling - CONFIG.audio['stats_path'] = output_file_path - CONFIG.audio['signal_norm'] = True + CONFIG.audio["stats_path"] = output_file_path + CONFIG.audio["signal_norm"] = True # remove redundant values - del CONFIG.audio['max_norm'] - del CONFIG.audio['min_level_db'] - del CONFIG.audio['symmetric_norm'] - del CONFIG.audio['clip_norm'] - stats['audio_config'] = CONFIG.audio + del CONFIG.audio["max_norm"] + del CONFIG.audio["min_level_db"] + del CONFIG.audio["symmetric_norm"] + del CONFIG.audio["clip_norm"] + stats["audio_config"] = CONFIG.audio np.save(output_file_path, stats, allow_pickle=True) - print(f' > stats saved to {output_file_path}') + print(f" > stats saved to {output_file_path}") if __name__ == "__main__": diff --git a/TTS/bin/convert_melgan_tflite.py b/TTS/bin/convert_melgan_tflite.py index 8df582da..06784abe 100644 --- a/TTS/bin/convert_melgan_tflite.py +++ b/TTS/bin/convert_melgan_tflite.py @@ -9,15 +9,9 @@ from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite parser = argparse.ArgumentParser() -parser.add_argument('--tf_model', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to tflite output binary.') +parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to tflite output binary.") args = parser.parse_args() # Set constants diff --git a/TTS/bin/convert_melgan_torch_to_tf.py b/TTS/bin/convert_melgan_torch_to_tf.py index 2eec6157..176bb992 100644 --- a/TTS/bin/convert_melgan_torch_to_tf.py +++ b/TTS/bin/convert_melgan_torch_to_tf.py @@ -8,27 +8,22 @@ import torch from TTS.utils.io import load_config from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) -from TTS.vocoder.tf.utils.generic_utils import \ - setup_generator as setup_tf_generator + compare_torch_tf, + convert_tf_name, + transfer_weights_torch_to_tf, +) +from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator from TTS.vocoder.tf.utils.io import save_checkpoint from TTS.vocoder.utils.generic_utils import setup_generator # prevent GPU use -os.environ['CUDA_VISIBLE_DEVICES'] = '' +os.environ["CUDA_VISIBLE_DEVICES"] = "" # define args parser = argparse.ArgumentParser() -parser.add_argument('--torch_model_path', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument( - '--output_path', - type=str, - help='path to output file including file name to save TF model.') +parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") args = parser.parse_args() # load model config @@ -38,9 +33,8 @@ num_speakers = 0 # init torch model model = setup_generator(c) -checkpoint = torch.load(args.torch_model_path, - map_location=torch.device('cpu')) -state_dict = checkpoint['model'] +checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +state_dict = checkpoint["model"] model.load_state_dict(state_dict) model.remove_weight_norm() state_dict = model.state_dict() @@ -48,7 +42,7 @@ state_dict = model.state_dict() # init tf model model_tf = setup_tf_generator(c) -common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" # get tf_model graph by passing an input # B x D x T dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32) @@ -66,10 +60,7 @@ for tf_name in tf_var_names: if tf_name in [name[0] for name in var_map]: continue tf_name_edited = convert_tf_name(tf_name) - ratios = [ - SequenceMatcher(None, torch_name, tf_name_edited).ratio() - for torch_name in torch_var_names - ] + ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] max_idx = np.argmax(ratios) matching_name = torch_var_names[max_idx] del torch_var_names[max_idx] @@ -107,10 +98,8 @@ model.inference_padding = 0 model_tf.inference_padding = 0 output_torch = model.inference(dummy_input_torch) output_tf = model_tf(dummy_input_tf, training=False) -assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf( - output_torch, output_tf) +assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf) # save tf model -save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'], - args.output_path) -print(' > Model conversion is successfully completed :).') +save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path) +print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/convert_tacotron2_tflite.py b/TTS/bin/convert_tacotron2_tflite.py index 2fddf4b0..2a7926a8 100644 --- a/TTS/bin/convert_tacotron2_tflite.py +++ b/TTS/bin/convert_tacotron2_tflite.py @@ -10,15 +10,9 @@ from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite parser = argparse.ArgumentParser() -parser.add_argument('--tf_model', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to tflite output binary.') +parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to tflite output binary.") args = parser.parse_args() # Set constants diff --git a/TTS/bin/convert_tacotron2_torch_to_tf.py b/TTS/bin/convert_tacotron2_torch_to_tf.py index 71fb8d5e..b4aafa9e 100644 --- a/TTS/bin/convert_tacotron2_torch_to_tf.py +++ b/TTS/bin/convert_tacotron2_torch_to_tf.py @@ -8,27 +8,20 @@ import numpy as np import tensorflow as tf import torch from TTS.tts.tf.models.tacotron2 import Tacotron2 -from TTS.tts.tf.utils.convert_torch_to_tf_utils import ( - compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf) +from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf from TTS.tts.tf.utils.generic_utils import save_checkpoint from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.text.symbols import phonemes, symbols from TTS.utils.io import load_config -sys.path.append('/home/erogol/Projects') -os.environ['CUDA_VISIBLE_DEVICES'] = '' +sys.path.append("/home/erogol/Projects") +os.environ["CUDA_VISIBLE_DEVICES"] = "" parser = argparse.ArgumentParser() -parser.add_argument('--torch_model_path', - type=str, - help='Path to target torch model to be converted to TF.') -parser.add_argument('--config_path', - type=str, - help='Path to config file of torch model.') -parser.add_argument('--output_path', - type=str, - help='path to output file including file name to save TF model.') +parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.") +parser.add_argument("--config_path", type=str, help="Path to config file of torch model.") +parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.") args = parser.parse_args() # load model config @@ -39,51 +32,48 @@ num_speakers = 0 # init torch model num_chars = len(phonemes) if c.use_phonemes else len(symbols) model = setup_model(num_chars, num_speakers, c) -checkpoint = torch.load(args.torch_model_path, - map_location=torch.device('cpu')) -state_dict = checkpoint['model'] +checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu")) +state_dict = checkpoint["model"] model.load_state_dict(state_dict) # init tf model -model_tf = Tacotron2(num_chars=num_chars, - num_speakers=num_speakers, - r=model.decoder.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder) +model_tf = Tacotron2( + num_chars=num_chars, + num_speakers=num_speakers, + r=model.decoder.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, +) # set initial layer mapping - these are not captured by the below heuristic approach # TODO: set layer names so that we can remove these manual matching -common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE' +common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE" var_map = [ - ('embedding/embeddings:0', 'embedding.weight'), - ('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0', - 'encoder.lstm.weight_ih_l0'), - ('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0', - 'encoder.lstm.weight_hh_l0'), - ('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0', - 'encoder.lstm.weight_ih_l0_reverse'), - ('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0', - 'encoder.lstm.weight_hh_l0_reverse'), - ('encoder/lstm/forward_lstm/lstm_cell_1/bias:0', - ('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')), - ('encoder/lstm/backward_lstm/lstm_cell_2/bias:0', - ('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')), - ('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'), - ('decoder/linear_projection/kernel:0', - 'decoder.linear_projection.linear_layer.weight'), - ('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight') + ("embedding/embeddings:0", "embedding.weight"), + ("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"), + ("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"), + ("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"), + ("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"), + ("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")), + ( + "encoder/lstm/backward_lstm/lstm_cell_2/bias:0", + ("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"), + ), + ("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"), + ("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"), + ("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"), ] # %% @@ -101,10 +91,7 @@ for tf_name in tf_var_names: if tf_name in [name[0] for name in var_map]: continue tf_name_edited = convert_tf_name(tf_name) - ratios = [ - SequenceMatcher(None, torch_name, tf_name_edited).ratio() - for torch_name in torch_var_names - ] + ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names] max_idx = np.argmax(ratios) matching_name = torch_var_names[max_idx] del torch_var_names[max_idx] @@ -124,25 +111,21 @@ input_ids = torch.randint(0, 24, (1, 128)).long() o_t = model.embedding(input_ids) o_tf = model_tf.embedding(input_ids.detach().numpy()) -assert abs(o_t.detach().numpy() - - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - - o_tf.numpy()).sum() +assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum() # compare encoder outputs oo_en = model.encoder.inference(o_t.transpose(1, 2)) ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False) assert compare_torch_tf(oo_en, ooo_en) < 1e-5 -#pylint: disable=redefined-builtin +# pylint: disable=redefined-builtin # compare decoder.attention_rnn inp = torch.rand([1, 768]) inp_tf = inp.numpy() -model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access +model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access output, cell_state = model.decoder.attention_rnn(inp) states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, - states[2], - training=False) +output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False) assert compare_torch_tf(output, output_tf).mean() < 1e-5 query = output @@ -153,8 +136,7 @@ inputs_tf = inputs.numpy() # compare decoder.attention model.decoder.attention.init_states(inputs) processes_inputs = model.decoder.attention.preprocess_inputs(inputs) -loc_attn, proc_query = model.decoder.attention.get_location_attention( - query, processes_inputs) +loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs) context = model.decoder.attention(query, inputs, processes_inputs, None) attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1] @@ -169,13 +151,10 @@ assert compare_torch_tf(context, context_tf) < 1e-5 # compare decoder.decoder_rnn input = torch.rand([1, 1536]) input_tf = input.numpy() -model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access -output, cell_state = model.decoder.decoder_rnn( - input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) +model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access +output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell]) states = model_tf.decoder.build_decoder_initial_states(1, 512, 128) -output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, - states[3], - training=False) +output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False) assert abs(input - input_tf).mean() < 1e-5 assert compare_torch_tf(output, output_tf).mean() < 1e-5 @@ -198,12 +177,10 @@ assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4 outputs_torch = model.inference(input_ids) outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy())) print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean()) -assert compare_torch_tf(outputs_torch[2][:, 50, :], - outputs_tf[2][:, 50, :]) < 1e-5 +assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5 assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4 # %% # save tf model -save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'], - checkpoint['r'], args.output_path) -print(' > Model conversion is successfully completed :).') +save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path) +print(" > Model conversion is successfully completed :).") diff --git a/TTS/bin/distribute.py b/TTS/bin/distribute.py index 390bd738..6b1c6fd6 100644 --- a/TTS/bin/distribute.py +++ b/TTS/bin/distribute.py @@ -15,26 +15,19 @@ def main(): Call train.py as a new process and pass command arguments """ parser = argparse.ArgumentParser() + parser.add_argument("--script", type=str, help="Target training script to distibute.") parser.add_argument( - '--script', - type=str, - help='Target training script to distibute.') - parser.add_argument( - '--continue_path', + "--continue_path", type=str, help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.', - default='', - required='--config_path' not in sys.argv) + default="", + required="--config_path" not in sys.argv, + ) parser.add_argument( - '--restore_path', - type=str, - help='Model file to be restored. Use to finetune a model.', - default='') + "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" + ) parser.add_argument( - '--config_path', - type=str, - help='Path to config file for training.', - required='--continue_path' not in sys.argv + "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv ) args = parser.parse_args() @@ -44,20 +37,20 @@ def main(): # set arguments for train.py folder_path = pathlib.Path(__file__).parent.absolute() command = [os.path.join(folder_path, args.script)] - command.append('--continue_path={}'.format(args.continue_path)) - command.append('--restore_path={}'.format(args.restore_path)) - command.append('--config_path={}'.format(args.config_path)) - command.append('--group_id=group_{}'.format(group_id)) - command.append('') + command.append("--continue_path={}".format(args.continue_path)) + command.append("--restore_path={}".format(args.restore_path)) + command.append("--config_path={}".format(args.config_path)) + command.append("--group_id=group_{}".format(group_id)) + command.append("") # run processes processes = [] for i in range(num_gpus): my_env = os.environ.copy() my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) - command[-1] = '--rank={}'.format(i) - stdout = None if i == 0 else open(os.devnull, 'w') - p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env) + command[-1] = "--rank={}".format(i) + stdout = None if i == 0 else open(os.devnull, "w") + p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env) processes.append(p) print(command) @@ -65,5 +58,5 @@ def main(): p.wait() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/TTS/bin/find_unique_chars.py b/TTS/bin/find_unique_chars.py index f9b6827b..b7056e01 100644 --- a/TTS/bin/find_unique_chars.py +++ b/TTS/bin/find_unique_chars.py @@ -7,30 +7,24 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name def main(): - # pylint: disable=bad-continuation - parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n''' - - '''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\ - ''' + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Find all the unique characters or phonemes in a dataset.\n\n""" + """Target dataset must be defined in TTS.tts.datasets.preprocess\n\n""" + """ Example runs: python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv - ''', formatter_class=RawTextHelpFormatter) - - parser.add_argument( - '--dataset', - type=str, - default='', - help='One of the target dataset names in TTS.tts.datasets.preprocess.' - ) - - parser.add_argument( - '--meta_file', - type=str, - default=None, - help='Path to the transcriptions file of the dataset.' + """, + formatter_class=RawTextHelpFormatter, ) + parser.add_argument( + "--dataset", type=str, default="", help="One of the target dataset names in TTS.tts.datasets.preprocess." + ) + + parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.") + args = parser.parse_args() preprocessor = get_preprocessor_by_name(args.dataset) diff --git a/TTS/bin/resample.py b/TTS/bin/resample.py index 080e2bad..7d358d4d 100644 --- a/TTS/bin/resample.py +++ b/TTS/bin/resample.py @@ -7,15 +7,17 @@ from argparse import RawTextHelpFormatter from multiprocessing import Pool from tqdm import tqdm + def resample_file(func_args): filename, output_sr = func_args y, sr = librosa.load(filename, sr=output_sr) librosa.output.write_wav(filename, y, sr) -if __name__ == '__main__': + +if __name__ == "__main__": parser = argparse.ArgumentParser( - description='''Resample a folder recusively with librosa + description="""Resample a folder recusively with librosa Can be used in place or create a copy of the folder as an output.\n\n Example run: python TTS/bin/resample.py @@ -23,46 +25,52 @@ if __name__ == '__main__': --output_sr 22050 --output_dir /root/resampled_LJSpeech-1.1/ --n_jobs 24 - ''', - formatter_class=RawTextHelpFormatter) + """, + formatter_class=RawTextHelpFormatter, + ) - parser.add_argument('--input_dir', - type=str, - default=None, - required=True, - help='Path of the folder containing the audio files to resample') + parser.add_argument( + "--input_dir", + type=str, + default=None, + required=True, + help="Path of the folder containing the audio files to resample", + ) - parser.add_argument('--output_sr', - type=int, - default=22050, - required=False, - help='Samlple rate to which the audio files should be resampled') + parser.add_argument( + "--output_sr", + type=int, + default=22050, + required=False, + help="Samlple rate to which the audio files should be resampled", + ) - parser.add_argument('--output_dir', - type=str, - default=None, - required=False, - help='Path of the destination folder. If not defined, the operation is done in place') + parser.add_argument( + "--output_dir", + type=str, + default=None, + required=False, + help="Path of the destination folder. If not defined, the operation is done in place", + ) - parser.add_argument('--n_jobs', - type=int, - default=None, - help='Number of threads to use, by default it uses all cores') + parser.add_argument( + "--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores" + ) args = parser.parse_args() if args.output_dir: - print('Recursively copying the input folder...') + print("Recursively copying the input folder...") copy_tree(args.input_dir, args.output_dir) args.input_dir = args.output_dir - print('Resampling the audio files...') - audio_files = glob.glob(os.path.join(args.input_dir, '**/*.wav'), recursive=True) - print(f'Found {len(audio_files)} files...') - audio_files = list(zip(audio_files, len(audio_files)*[args.output_sr])) + print("Resampling the audio files...") + audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True) + print(f"Found {len(audio_files)} files...") + audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr])) with Pool(processes=args.n_jobs) as p: with tqdm(total=len(audio_files)) as pbar: for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)): pbar.update() - print('Done !') + print("Done !") diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 8b96d945..aca245bb 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -4,6 +4,7 @@ import argparse import sys from argparse import RawTextHelpFormatter + # pylint: disable=redefined-outer-name, unused-argument from pathlib import Path @@ -14,22 +15,20 @@ from TTS.utils.synthesizer import Synthesizer def str2bool(v): if isinstance(v, bool): return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + if v.lower() in ("yes", "true", "t", "y", "1"): return True - if v.lower() in ('no', 'false', 'f', 'n', '0'): + if v.lower() in ("no", "false", "f", "n", "0"): return False - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") def main(): - # pylint: disable=bad-continuation - parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n''' - - '''You can either use your trained model or choose a model from the provided list.\n\n'''\ - - '''If you don't specify any models, then it uses LJSpeech based English models\n\n'''\ - - ''' + # pylint: disable=bad-option-value + parser = argparse.ArgumentParser( + description="""Synthesize speech on command line.\n\n""" + """You can either use your trained model or choose a model from the provided list.\n\n""" + """If you don't specify any models, then it uses LJSpeech based English models\n\n""" + """ Example runs: # list provided models @@ -51,106 +50,80 @@ def main(): ./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav --vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json - ''', - formatter_class=RawTextHelpFormatter) + """, + formatter_class=RawTextHelpFormatter, + ) parser.add_argument( - '--list_models', + "--list_models", type=str2bool, - nargs='?', + nargs="?", const=True, default=False, - help='list available pre-trained tts and vocoder models.' - ) - parser.add_argument( - '--text', - type=str, - default=None, - help='Text to generate speech.' - ) + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument("--text", type=str, default=None, help="Text to generate speech.") # Args for running pre-trained TTS models. parser.add_argument( - '--model_name', + "--model_name", type=str, default="tts_models/en/ljspeech/speedy-speech-wn", - help= - 'Name of one of the pre-trained tts models in format //' + help="Name of one of the pre-trained tts models in format //", ) parser.add_argument( - '--vocoder_name', + "--vocoder_name", type=str, default=None, - help= - 'Name of one of the pre-trained vocoder models in format //' + help="Name of one of the pre-trained vocoder models in format //", ) # Args for running custom models + parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.") parser.add_argument( - '--config_path', - default=None, - type=str, - help='Path to model config file.' - ) - parser.add_argument( - '--model_path', + "--model_path", type=str, default=None, - help='Path to model file.', + help="Path to model file.", ) parser.add_argument( - '--out_path', + "--out_path", type=str, - default='tts_output.wav', - help='Output wav file path.', + default="tts_output.wav", + help="Output wav file path.", ) + parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False) parser.add_argument( - '--use_cuda', - type=bool, - help='Run model on CUDA.', - default=False - ) - parser.add_argument( - '--vocoder_path', + "--vocoder_path", type=str, - help= - 'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).', + help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).", default=None, ) - parser.add_argument( - '--vocoder_config_path', - type=str, - help='Path to vocoder model config file.', - default=None) + parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None) # args for multi-speaker synthesis + parser.add_argument("--speakers_json", type=str, help="JSON file for multi-speaker model.", default=None) parser.add_argument( - '--speakers_json', - type=str, - help="JSON file for multi-speaker model.", - default=None) - parser.add_argument( - '--speaker_idx', + "--speaker_idx", type=str, help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.", - default=None) - parser.add_argument( - '--gst_style', - help="Wav path file for GST stylereference.", - default=None) + default=None, + ) + parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None) # aux args parser.add_argument( - '--save_spectogram', + "--save_spectogram", type=bool, help="If true save raw spectogram for further (vocoder) processing in out_path.", - default=False) + default=False, + ) args = parser.parse_args() # print the description if either text or list_models is not set if args.text is None and not args.list_models: - parser.parse_args(['-h']) + parser.parse_args(["-h"]) # load model manager path = Path(__file__).parent / "../.models.json" @@ -169,7 +142,7 @@ def main(): # CASE2: load pre-trained models if args.model_name is not None: model_path, config_path, model_item = manager.download_model(args.model_name) - args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name + args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name if args.vocoder_name is not None: vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name) diff --git a/TTS/bin/train_align_tts.py b/TTS/bin/train_align_tts.py index 1b3e7d52..16940f1e 100644 --- a/TTS/bin/train_align_tts.py +++ b/TTS/bin/train_align_tts.py @@ -25,12 +25,11 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env -if __name__ == '__main__': +if __name__ == "__main__": use_cuda, num_gpus = setup_torch_training_env(True, False) # torch.autograd.set_detect_anomaly(True) @@ -44,10 +43,9 @@ if __name__ == '__main__': compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,8 +54,10 @@ if __name__ == '__main__': enable_eos_bos=c.enable_eos_bos_chars, use_noise_augment=not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding - and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -72,9 +72,9 @@ if __name__ == '__main__': collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader def format_data(data): @@ -94,10 +94,7 @@ if __name__ == '__main__': speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] - for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -109,18 +106,15 @@ if __name__ == '__main__': mel_lengths = mel_lengths.cuda(non_blocking=True) if speaker_c is not None: speaker_c = speaker_c.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, item_idx + return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx - def train(data_loader, model, criterion, optimizer, scheduler, ap, - global_step, epoch, training_phase): + def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -130,8 +124,16 @@ if __name__ == '__main__': start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, _ = format_data(data) + ( + text_input, + text_lengths, + mel_targets, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -141,36 +143,32 @@ if __name__ == '__main__': # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, - text_lengths, - mel_targets, - mel_lengths, - g=speaker_c, - phase=training_phase) + text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase + ) # compute loss - loss_dict = criterion(logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase) + loss_dict = criterion( + logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase, + ) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -178,25 +176,21 @@ if __name__ == '__main__': scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, - num_gpus) - loss_dict['loss_ssim'] = reduce_tensor( - loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor( - loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, - num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -210,48 +204,43 @@ if __name__ == '__main__': # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress if global_step % c.print_step == 0: log_dict = { - "avg_spec_length": [avg_spec_length, - 1], # value, precision + "avg_spec_length": [avg_spec_length, 1], # value, precision "avg_text_length": [avg_text_length, 1], "step_time": [step_time, 4], "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, - keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -259,8 +248,7 @@ if __name__ == '__main__': # Diagnostic visualizations if decoder_output is not None: idx = np.random.randint(mel_targets.shape[0]) - pred_spec = decoder_output[idx].detach().data.cpu( - ).numpy().T + pred_spec = decoder_output[idx].detach().data.cpu().numpy().T gt_spec = mel_targets[idx].data.cpu().numpy().T align_img = alignments[idx].data.cpu() @@ -274,14 +262,11 @@ if __name__ == '__main__': # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats - c_logger.print_train_epoch_end(global_step, epoch, epoch_time, - keep_avg) + c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg) # Plot Epoch Stats if args.rank == 0: @@ -293,8 +278,7 @@ if __name__ == '__main__': return keep_avg.avg_values, global_step @torch.no_grad() - def evaluate(data_loader, model, criterion, ap, global_step, epoch, - training_phase): + def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase): model.eval() epoch_time = 0 keep_avg = KeepAverage() @@ -304,50 +288,41 @@ if __name__ == '__main__': start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - _, _, _ = format_data(data) + text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data) # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward( - text_input, - text_lengths, - mel_targets, - mel_lengths, - g=speaker_c, - phase=training_phase) + text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase + ) # compute loss - loss_dict = criterion(logp, - decoder_output, - mel_targets, - mel_lengths, - dur_output, - dur_mas_output, - text_lengths, - global_step, - phase=training_phase) - + loss_dict = criterion( + logp, + decoder_output, + mel_targets, + mel_lengths, + dur_output, + dur_mas_output, + text_lengths, + global_step, + phase=training_phase, + ) # step time step_time = time.time() - start_time epoch_time += step_time # compute alignment score - align_error = 1 - alignment_diagonal_score(alignments, - binary=True) - loss_dict['align_error'] = align_error + align_error = 1 - alignment_diagonal_score(alignments, binary=True) + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor( - loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor( - loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor( - loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, - num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -361,12 +336,11 @@ if __name__ == '__main__': # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: - c_logger.print_eval_step(num_iter, loss_dict, - keep_avg.avg_values) + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if args.rank == 0: # Diagnostic visualizations @@ -376,19 +350,14 @@ if __name__ == '__main__': align_img = alignments[idx].data.cpu() eval_figures = { - "prediction": plot_spectrogram(pred_spec, - ap, - output_fig=False), - "ground_truth": plot_spectrogram(gt_spec, - ap, - output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), + "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -401,7 +370,7 @@ if __name__ == '__main__': "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -413,9 +382,9 @@ if __name__ == '__main__': print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list( - speaker_mapping.keys())[randrange( - len(speaker_mapping) - 1)]]['embedding'] + speaker_embedding = speaker_mapping[ + list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)] + ]["embedding"] speaker_id = None else: speaker_id = 0 @@ -437,25 +406,22 @@ if __name__ == '__main__': speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format( - idx)] = plot_spectrogram(postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -464,69 +430,55 @@ if __name__ == '__main__': global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols num_chars = len(model_characters) # load data instances - meta_data_train, meta_data_eval = load_meta_data(c.datasets, - eval_split=True) + meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) # set the portion of the data used for training if set in config.json - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int( - len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int( - len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers - num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers( - c, args, meta_data_train, OUT_PATH) + num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) # setup model - model = setup_model(num_chars, - num_speakers, - c, - speaker_embedding_dim=speaker_embedding_dim) - optimizer = RAdam(model.parameters(), - lr=c.lr, - weight_decay=0, - betas=(0.9, 0.98), - eps=1e-9) + model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim) + optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9) criterion = AlignTTSLoss(c) if args.restore_path: - print( - f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + print(f" > Restoring from {os.path.basename(args.restore_path)} ...") + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -539,9 +491,7 @@ if __name__ == '__main__': model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -549,16 +499,14 @@ if __name__ == '__main__': print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -573,9 +521,9 @@ if __name__ == '__main__': if not True in vals: phase = 0 else: - phase = len(c.phase_start_steps) - [ - i < global_step for i in c.phase_start_steps - ][::-1].index(True) - 1 + phase = ( + len(c.phase_start_steps) - [i < global_step for i in c.phase_start_steps][::-1].index(True) - 1 + ) else: phase = None return phase @@ -584,32 +532,30 @@ if __name__ == '__main__': cur_phase = set_phase() print(f"\n > Current AlignTTS phase: {cur_phase}") c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - scheduler, ap, - global_step, epoch, - cur_phase) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch, cur_phase) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, - best_loss, - model, - optimizer, - global_step, - epoch, - 1, - OUT_PATH, - model_characters, - keep_all_best=keep_all_best, - keep_after=keep_after) + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_encoder.py b/TTS/bin/train_encoder.py index 12fba6e1..a2d917ac 100644 --- a/TTS/bin/train_encoder.py +++ b/TTS/bin/train_encoder.py @@ -12,14 +12,17 @@ from torch.utils.data import DataLoader from TTS.speaker_encoder.dataset import MyDataset from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss from TTS.speaker_encoder.model import SpeakerEncoder -from TTS.speaker_encoder.utils.generic_utils import \ - check_config_speaker_encoder, save_best_model +from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model from TTS.speaker_encoder.utils.visual import plot_embeddings from TTS.tts.datasets.preprocess import load_meta_data from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import (count_parameters, - create_experiment_folder, get_git_branch, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import ( + count_parameters, + create_experiment_folder, + get_git_branch, + remove_experiment_folder, + set_init_dict, +) from TTS.utils.io import copy_model_files, load_config from TTS.utils.radam import RAdam from TTS.utils.tensorboard_logger import TensorboardLogger @@ -34,28 +37,30 @@ print(" > Using CUDA: ", use_cuda) print(" > Number of GPUs: ", num_gpus) -def setup_loader(ap: AudioProcessor, - is_val: bool = False, - verbose: bool = False): +def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False): if is_val: loader = None else: - dataset = MyDataset(ap, - meta_data_eval if is_val else meta_data_train, - voice_len=1.6, - num_utter_per_speaker=c.num_utters_per_speaker, - num_speakers_in_batch=c.num_speakers_in_batch, - skip_speakers=False, - storage_size=c.storage["storage_size"], - sample_from_storage_p=c.storage["sample_from_storage_p"], - additive_noise=c.storage["additive_noise"], - verbose=verbose) + dataset = MyDataset( + ap, + meta_data_eval if is_val else meta_data_train, + voice_len=1.6, + num_utter_per_speaker=c.num_utters_per_speaker, + num_speakers_in_batch=c.num_speakers_in_batch, + skip_speakers=False, + storage_size=c.storage["storage_size"], + sample_from_storage_p=c.storage["sample_from_storage_p"], + additive_noise=c.storage["additive_noise"], + verbose=verbose, + ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=c.num_speakers_in_batch, - shuffle=False, - num_workers=c.num_loader_workers, - collate_fn=dataset.collate_fn) + loader = DataLoader( + dataset, + batch_size=c.num_speakers_in_batch, + shuffle=False, + num_workers=c.num_loader_workers, + collate_fn=dataset.collate_fn, + ) return loader @@ -63,7 +68,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): data_loader = setup_loader(ap, is_val=False, verbose=True) model.train() epoch_time = 0 - best_loss = float('inf') + best_loss = float("inf") avg_loss = 0 avg_loader_time = 0 end_time = time.time() @@ -89,9 +94,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): outputs = model(inputs) # loss computation - loss = criterion( - outputs.view(c.num_speakers_in_batch, - outputs.shape[0] // c.num_speakers_in_batch, -1)) + loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1)) loss.backward() grad_norm, _ = check_update(model, c.grad_clip) optimizer.step() @@ -100,11 +103,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): epoch_time += step_time # Averaged Loss and Averaged Loader Time - avg_loss = 0.01 * loss.item() \ - + 0.99 * avg_loss if avg_loss != 0 else loss.item() - avg_loader_time = 1/c.num_loader_workers * loader_time + \ - (c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time - current_lr = optimizer.param_groups[0]['lr'] + avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item() + avg_loader_time = ( + 1 / c.num_loader_workers * loader_time + (c.num_loader_workers - 1) / c.num_loader_workers * avg_loader_time + if avg_loader_time != 0 + else loader_time + ) + current_lr = optimizer.param_groups[0]["lr"] if global_step % c.steps_plot_stats == 0: # Plot Training Epoch Stats @@ -113,13 +118,12 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): "lr": current_lr, "grad_norm": grad_norm, "step_time": step_time, - "avg_loader_time": avg_loader_time + "avg_loader_time": avg_loader_time, } tb_logger.tb_train_epoch_stats(global_step, train_stats) figures = { # FIXME: not constant - "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), - 10), + "UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10), } tb_logger.tb_train_figures(global_step, figures) @@ -127,13 +131,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step): print( " | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} " "StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format( - global_step, loss.item(), avg_loss, grad_norm, step_time, - loader_time, avg_loader_time, current_lr), - flush=True) + global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr + ), + flush=True, + ) # save best model - best_loss = save_best_model(model, optimizer, avg_loss, best_loss, - OUT_PATH, global_step) + best_loss = save_best_model(model, optimizer, avg_loss, best_loss, OUT_PATH, global_step) end_time = time.time() return avg_loss, global_step @@ -145,14 +149,16 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_eval ap = AudioProcessor(**c.audio) - model = SpeakerEncoder(input_dim=c.model['input_dim'], - proj_dim=c.model['proj_dim'], - lstm_dim=c.model['lstm_dim'], - num_lstm_layers=c.model['num_lstm_layers']) + model = SpeakerEncoder( + input_dim=c.model["input_dim"], + proj_dim=c.model["proj_dim"], + lstm_dim=c.model["lstm_dim"], + num_lstm_layers=c.model["num_lstm_layers"], + ) optimizer = RAdam(model.parameters(), lr=c.lr) if c.loss == "ge2e": - criterion = GE2ELoss(loss_method='softmax') + criterion = GE2ELoss(loss_method="softmax") elif c.loss == "angleproto": criterion = AngleProtoLoss() else: @@ -166,7 +172,7 @@ def main(args): # pylint: disable=redefined-outer-name # optimizer.load_state_dict(checkpoint['optimizer']) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) except KeyError: print(" > Partial model initialization.") model_dict = model.state_dict() @@ -174,10 +180,9 @@ def main(args): # pylint: disable=redefined-outer-name model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -186,9 +191,7 @@ def main(args): # pylint: disable=redefined-outer-name criterion.cuda() if c.lr_decay: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -199,55 +202,39 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) global_step = args.restore_step - _, global_step = train(model, criterion, optimizer, scheduler, ap, - global_step) + _, global_step = train(model, criterion, optimizer, scheduler, ap, global_step) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - '--restore_path', - type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).', - default=0) + "--restore_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).", default=0 + ) parser.add_argument( - '--config_path', + "--config_path", type=str, required=True, - help='Path to config file for training.', + help="Path to config file for training.", ) - parser.add_argument('--debug', - type=bool, - default=True, - help='Do not verify commit integrity to run training.') - parser.add_argument( - '--data_path', - type=str, - default='', - help='Defines the data path. It overwrites config.json.') - parser.add_argument('--output_path', - type=str, - help='path for training outputs.', - default='') - parser.add_argument('--output_folder', - type=str, - default='', - help='folder name for training outputs.') + parser.add_argument("--debug", type=bool, default=True, help="Do not verify commit integrity to run training.") + parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.") + parser.add_argument("--output_path", type=str, help="path for training outputs.", default="") + parser.add_argument("--output_folder", type=str, default="", help="folder name for training outputs.") args = parser.parse_args() # setup output paths and read configs c = load_config(args.config_path) check_config_speaker_encoder(c) _ = os.path.dirname(os.path.realpath(__file__)) - if args.data_path != '': + if args.data_path != "": c.data_path = args.data_path - if args.output_path == '': + if args.output_path == "": OUT_PATH = os.path.join(_, c.output_path) else: OUT_PATH = args.output_path - if args.output_folder == '': + if args.output_folder == "": OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug) else: OUT_PATH = os.path.join(OUT_PATH, args.output_folder) @@ -259,7 +246,7 @@ if __name__ == '__main__': copy_model_files(c, args.config_path, OUT_PATH, new_fields) LOG_DIR = OUT_PATH - tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder') + tb_logger = TensorboardLogger(LOG_DIR, model_name="Speaker_Encoder") try: main(args) diff --git a/TTS/bin/train_glow_tts.py b/TTS/bin/train_glow_tts.py index 117de531..01b62c14 100644 --- a/TTS/bin/train_glow_tts.py +++ b/TTS/bin/train_glow_tts.py @@ -8,6 +8,7 @@ import traceback from random import randrange import torch + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader @@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env @@ -44,19 +44,21 @@ def setup_loader(ap, r, is_val=False, verbose=False): compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, use_phonemes=c.use_phonemes, phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, - use_noise_augment=c['use_noise_augment'] and not is_val, + use_noise_augment=c["use_noise_augment"] and not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -95,9 +97,7 @@ def format_data(data): speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -112,13 +112,22 @@ def format_data(data): speaker_c = speaker_c.cuda(non_blocking=True) if attn_mask is not None: attn_mask = attn_mask.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, item_idx + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + item_idx, + ) def data_depended_init(data_loader, model): """Data depended initialization for activation normalization.""" - if hasattr(model, 'module'): + if hasattr(model, "module"): for f in model.module.decoder.flows: if getattr(f, "set_ddi", False): f.set_ddi(True) @@ -134,17 +143,15 @@ def data_depended_init(data_loader, model): for _, data in enumerate(data_loader): # format data - text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\ - _, _, attn_mask, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, spekaer_embed, _, _, attn_mask, _ = format_data(data) # forward pass model - _ = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed) + _ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed) if num_iter == c.data_dep_init_iter: break num_iter += 1 - if hasattr(model, 'module'): + if hasattr(model, "module"): for f in model.module.decoder.flows: if getattr(f, "set_ddi", False): f.set_ddi(False) @@ -155,15 +162,13 @@ def data_depended_init(data_loader, model): return model -def train(data_loader, model, criterion, optimizer, scheduler, - ap, global_step, epoch): +def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -173,8 +178,17 @@ def train(data_loader, model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, _ = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -184,24 +198,22 @@ def train(data_loader, model, criterion, optimizer, scheduler, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c + ) # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -209,20 +221,20 @@ def train(data_loader, model, criterion, optimizer, scheduler, scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -236,9 +248,9 @@ def train(data_loader, model, criterion, optimizer, scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress @@ -250,26 +262,29 @@ def train(data_loader, model, criterion, optimizer, scheduler, "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -278,7 +293,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1] - if hasattr(model, 'module'): + if hasattr(model, "module"): spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) else: spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) @@ -299,9 +314,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # Sample audio train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -328,16 +341,15 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - _, _, attn_mask, _ = format_data(data) + text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data) # forward pass model z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward( - text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c) + text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c + ) # compute loss - loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, - o_dur_log, o_total_dur, text_lengths) + loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths) # step time step_time = time.time() - start_time @@ -345,13 +357,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -365,7 +377,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -375,7 +387,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # Diagnostic visualizations # direct pass on model for spec predictions target_speaker = None if speaker_c is None else speaker_c[:1] - if hasattr(model, 'module'): + if hasattr(model, "module"): spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker) else: spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker) @@ -389,13 +401,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_figures = { "prediction": plot_spectrogram(const_spec, ap), "ground_truth": plot_spectrogram(gt_spec, ap), - "alignment": plot_alignment(align_img) + "alignment": plot_alignment(align_img), } # Sample audio eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -408,7 +419,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -420,7 +431,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ + "embedding" + ] speaker_id = None else: speaker_id = 0 @@ -442,25 +455,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -470,13 +480,12 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols @@ -486,10 +495,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) # set the portion of the data used for training - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -501,26 +510,25 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(f" > Model restored from step {checkpoint['step']:d}", - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(f" > Model restored from step {checkpoint['step']:d}", flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -533,9 +541,7 @@ def main(args): # pylint: disable=redefined-outer-name model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -543,16 +549,14 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -562,25 +566,32 @@ def main(args): # pylint: disable=redefined-outer-name model = data_depended_init(train_loader, model) for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - scheduler, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, best_loss, model, optimizer, - global_step, epoch, c.r, OUT_PATH, model_characters, - keep_all_best=keep_all_best, keep_after=keep_after) + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + c.r, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_speedy_speech.py b/TTS/bin/train_speedy_speech.py index 026413bb..7959c3c1 100644 --- a/TTS/bin/train_speedy_speech.py +++ b/TTS/bin/train_speedy_speech.py @@ -10,6 +10,7 @@ from random import randrange import torch from TTS.utils.arguments import parse_arguments, process_args + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.utils.data import DataLoader @@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed, reduce_tensor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam from TTS.utils.training import NoamLR, setup_torch_training_env @@ -44,10 +44,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): compute_linear_spec=False, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,7 +55,10 @@ def setup_loader(ap, r, is_val=False, verbose=False): enable_eos_bos=c.enable_eos_bos_chars, use_noise_augment=not is_val, verbose=verbose, - speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None) + speaker_mapping=speaker_mapping + if c.use_speaker_embedding and c.use_external_speaker_embedding_file + else None, + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -95,9 +97,7 @@ def format_data(data): speaker_c = data[8] else: # return speaker_id to be used by an embedding layer - speaker_c = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_c = torch.LongTensor(speaker_c) else: speaker_c = None @@ -105,7 +105,7 @@ def format_data(data): durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2]) for idx, am in enumerate(attn_mask): # compute raw durations - c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1] + c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1] # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True) c_idxs, counts = torch.unique(c_idxs, return_counts=True) dur = torch.ones([text_lengths[idx]]).to(counts.dtype) @@ -115,8 +115,10 @@ def format_data(data): extra_frames = dur.sum() - mel_lengths[idx] largest_idxs = torch.argsort(-dur)[:extra_frames] dur[largest_idxs] -= 1 - assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" - durations[idx, :text_lengths[idx]] = dur + assert ( + dur.sum() == mel_lengths[idx] + ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" + durations[idx, : text_lengths[idx]] = dur # dispatch data to GPU if use_cuda: text_input = text_input.cuda(non_blocking=True) @@ -127,19 +129,27 @@ def format_data(data): speaker_c = speaker_c.cuda(non_blocking=True) attn_mask = attn_mask.cuda(non_blocking=True) durations = durations.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, attn_mask, durations, item_idx + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + attn_mask, + durations, + item_idx, + ) -def train(data_loader, model, criterion, optimizer, scheduler, - ap, global_step, epoch): +def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -149,8 +159,18 @@ def train(data_loader, model, criterion, optimizer, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data) + ( + text_input, + text_lengths, + mel_targets, + mel_lengths, + speaker_c, + avg_text_length, + avg_spec_length, + _, + dur_target, + _, + ) = format_data(data) loader_time = time.time() - end_time @@ -160,23 +180,24 @@ def train(data_loader, model, criterion, optimizer, scheduler, # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c + ) # compute loss - loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + loss_dict = criterion( + decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths + ) # backward pass with loss scaling if c.mixed_precision: - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: - loss_dict['loss'].backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.grad_clip) + loss_dict["loss"].backward() + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() # setup lr @@ -184,21 +205,21 @@ def train(data_loader, model, criterion, optimizer, scheduler, scheduler.step() # current_lr - current_lr = optimizer.param_groups[0]['lr'] + current_lr = optimizer.param_groups[0]["lr"] # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -212,41 +233,43 @@ def train(data_loader, model, criterion, optimizer, scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress if global_step % c.print_step == 0: log_dict = { - "avg_spec_length": [avg_spec_length, 1], # value, precision "avg_text_length": [avg_text_length, 1], "step_time": [step_time, 4], "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats # reduce TB load if global_step % c.tb_plot_step == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm, - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters, - model_loss=loss_dict['loss']) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + 1, + OUT_PATH, + model_characters, + model_loss=loss_dict["loss"], + ) # wait all kernels to be completed torch.cuda.synchronize() @@ -267,9 +290,7 @@ def train(data_loader, model, criterion, optimizer, scheduler, # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -296,16 +317,18 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\ - _, _, _, dur_target, _ = format_data(data) + text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data) # forward pass model with torch.cuda.amp.autocast(enabled=c.mixed_precision): decoder_output, dur_output, alignments = model.forward( - text_input, text_lengths, mel_lengths, dur_target, g=speaker_c) + text_input, text_lengths, mel_lengths, dur_target, g=speaker_c + ) # compute loss - loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths) + loss_dict = criterion( + decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths + ) # step time step_time = time.time() - start_time @@ -313,14 +336,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments, binary=True) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus) - loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus) - loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) + loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus) + loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus) + loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -334,7 +357,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -350,13 +373,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_figures = { "prediction": plot_spectrogram(pred_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio eval_audio = ap.inv_melspectrogram(pred_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -369,7 +391,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -381,7 +403,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): print(" | > Synthesizing test sentences") if c.use_speaker_embedding: if c.use_external_speaker_embedding_file: - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] + speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][ + "embedding" + ] speaker_id = None else: speaker_id = 0 @@ -403,25 +427,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment) + except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -432,13 +453,12 @@ def main(args): # pylint: disable=redefined-outer-name global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping # Audio processor ap = AudioProcessor(**c.audio) - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # set model characters model_characters = phonemes if c.use_phonemes else symbols @@ -448,10 +468,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True) # set the portion of the data used for training if set in config.json - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -463,26 +483,25 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)} ...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: # TODO: fix optimizer init, model.cuda() needs to be called before # optimizer restore - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if c.reinit_layers: raise RuntimeError - model.load_state_dict(checkpoint['model']) - except: #pylint: disable=bare-except + model.load_state_dict(checkpoint["model"]) + except: # pylint: disable=bare-except print(" > Partial model initialization.") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['initial_lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["initial_lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -495,9 +514,7 @@ def main(args): # pylint: disable=redefined-outer-name model = DDP_th(model, device_ids=[args.rank]) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -505,16 +522,14 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define dataloaders train_loader = setup_loader(ap, 1, is_val=False, verbose=True) @@ -523,24 +538,32 @@ def main(args): # pylint: disable=redefined-outer-name global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer, - scheduler, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + train_avg_loss_dict, global_step = train( + train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch + ) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_loss'] + target_loss = train_avg_loss_dict["avg_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_loss'] - best_loss = save_best_model(target_loss, best_loss, model, optimizer, - global_step, epoch, c.r, OUT_PATH, model_characters, - keep_all_best=keep_all_best, keep_after=keep_after) + target_loss = eval_avg_loss_dict["avg_loss"] + best_loss = save_best_model( + target_loss, + best_loss, + model, + optimizer, + global_step, + epoch, + c.r, + OUT_PATH, + model_characters, + keep_all_best=keep_all_best, + keep_after=keep_after, + ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_tacotron.py b/TTS/bin/train_tacotron.py index ce41980d..cf5552fc 100644 --- a/TTS/bin/train_tacotron.py +++ b/TTS/bin/train_tacotron.py @@ -22,14 +22,17 @@ from TTS.tts.utils.synthesis import synthesis from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor -from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce, - init_distributed, reduce_tensor) -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam -from TTS.utils.training import (NoamLR, adam_weight_decay, check_update, - gradual_training_scheduler, set_weight_decay, - setup_torch_training_env) +from TTS.utils.training import ( + NoamLR, + adam_weight_decay, + check_update, + gradual_training_scheduler, + set_weight_decay, + setup_torch_training_env, +) use_cuda, num_gpus = setup_torch_training_env(True, False) @@ -42,13 +45,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): dataset = MyDataset( r, c.text_cleaner, - compute_linear_spec=c.model.lower() == 'tacotron', + compute_linear_spec=c.model.lower() == "tacotron", meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, - tp=c.characters if 'characters' in c.keys() else None, - add_blank=c['add_blank'] if 'add_blank' in c.keys() else False, - batch_group_size=0 if is_val else c.batch_group_size * - c.batch_size, + tp=c.characters if "characters" in c.keys() else None, + add_blank=c["add_blank"] if "add_blank" in c.keys() else False, + batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, max_seq_len=c.max_seq_len, phoneme_cache_path=c.phoneme_cache_path, @@ -56,11 +58,10 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): phoneme_language=c.phoneme_language, enable_eos_bos=c.enable_eos_bos_chars, verbose=verbose, - speaker_mapping=(speaker_mapping if ( - c.use_speaker_embedding - and c.use_external_speaker_embedding_file - ) else None) - ) + speaker_mapping=( + speaker_mapping if (c.use_speaker_embedding and c.use_external_speaker_embedding_file) else None + ), + ) if c.use_phonemes and c.compute_input_seq_cache: # precompute phonemes to have a better estimate of sequence lengths. @@ -75,11 +76,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None): collate_fn=dataset.collate_fn, drop_last=False, sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader + def format_data(data): # setup input data text_input = data[0] @@ -97,21 +99,16 @@ def format_data(data): speaker_embeddings = data[8] speaker_ids = None else: - speaker_ids = [ - speaker_mapping[speaker_name] for speaker_name in speaker_names - ] + speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names] speaker_ids = torch.LongTensor(speaker_ids) speaker_embeddings = None else: speaker_embeddings = None speaker_ids = None - # set stop targets view, we predict a single stop token per iteration. - stop_targets = stop_targets.view(text_input.shape[0], - stop_targets.size(1) // c.r, -1) - stop_targets = (stop_targets.sum(2) > - 0.0).unsqueeze(2).float().squeeze(2) + stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1) + stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) # dispatch data to GPU if use_cuda: @@ -126,17 +123,26 @@ def format_data(data): if speaker_embeddings is not None: speaker_embeddings = speaker_embeddings.cuda(non_blocking=True) - return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length + return ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + max_text_length, + max_spec_length, + ) -def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, - ap, global_step, epoch, scaler, scaler_st): +def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st): model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -145,7 +151,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + max_text_length, + max_spec_length, + ) = format_data(data) loader_time = time.time() - end_time global_step += 1 @@ -161,35 +178,65 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, with torch.cuda.amp.autocast(enabled=c.mixed_precision): # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + ( + decoder_output, + postnet_output, + alignments, + stop_tokens, + decoder_backward_output, + alignments_backward, + ) = model( + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=speaker_ids, + speaker_embeddings=speaker_embeddings, + ) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + text_input, + text_lengths, + mel_input, + mel_lengths, + speaker_ids=speaker_ids, + speaker_embeddings=speaker_embeddings, + ) decoder_backward_output = None alignments_backward = None # set the [alignment] lengths wrt reduction factor for guided attention if mel_lengths.max() % model.decoder.r != 0: - alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r + alignment_lengths = ( + mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r)) + ) // model.decoder.r else: - alignment_lengths = mel_lengths // model.decoder.r + alignment_lengths = mel_lengths // model.decoder.r # compute loss - loss_dict = criterion(postnet_output, decoder_output, mel_input, - linear_input, stop_tokens, stop_targets, - mel_lengths, decoder_backward_output, - alignments, alignment_lengths, - alignments_backward, text_lengths) + loss_dict = criterion( + postnet_output, + decoder_output, + mel_input, + linear_input, + stop_tokens, + stop_targets, + mel_lengths, + decoder_backward_output, + alignments, + alignment_lengths, + alignments_backward, + text_lengths, + ) # check nan loss - if torch.isnan(loss_dict['loss']).any(): - raise RuntimeError(f'Detected NaN loss at step {global_step}.') + if torch.isnan(loss_dict["loss"]).any(): + raise RuntimeError(f"Detected NaN loss at step {global_step}.") # optimizer step if c.mixed_precision: # model optimizer step in mixed precision mode - scaler.scale(loss_dict['loss']).backward() + scaler.scale(loss_dict["loss"]).backward() scaler.unscale_(optimizer) optimizer, current_lr = adam_weight_decay(optimizer) grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True) @@ -198,7 +245,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # stopnet optimizer step if c.separate_stopnet: - scaler_st.scale(loss_dict['stopnet_loss']).backward() + scaler_st.scale(loss_dict["stopnet_loss"]).backward() scaler.unscale_(optimizer_st) optimizer_st, _ = adam_weight_decay(optimizer_st) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) @@ -208,14 +255,14 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, grad_norm_st = 0 else: # main model optimizer step - loss_dict['loss'].backward() + loss_dict["loss"].backward() optimizer, current_lr = adam_weight_decay(optimizer) grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True) optimizer.step() # stopnet optimizer step if c.separate_stopnet: - loss_dict['stopnet_loss'].backward() + loss_dict["stopnet_loss"].backward() optimizer_st, _ = adam_weight_decay(optimizer_st) grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0) optimizer_st.step() @@ -224,17 +271,19 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # compute alignment error (the lower the better ) align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error step_time = time.time() - start_time epoch_time += step_time # aggregate losses from processes if num_gpus > 1: - loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) - loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) - loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus) - loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss'] + loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus) + loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus) + loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus) + loss_dict["stopnet_loss"] = ( + reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else loss_dict["stopnet_loss"] + ) # detach loss values loss_dict_new = dict() @@ -248,9 +297,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training progress @@ -262,8 +311,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, "loader_time": [loader_time, 2], "current_lr": current_lr, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # Plot Training Iter Stats @@ -273,7 +321,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, "lr": current_lr, "grad_norm": grad_norm, "grad_norm_st": grad_norm_st, - "step_time": step_time + "step_time": step_time, } iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -281,17 +329,26 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, - optimizer_st=optimizer_st, - model_loss=loss_dict['postnet_loss'], - characters=model_characters, - scaler=scaler.state_dict() if c.mixed_precision else None) + save_checkpoint( + model, + optimizer, + global_step, + epoch, + model.decoder.r, + OUT_PATH, + optimizer_st=optimizer_st, + model_loss=loss_dict["postnet_loss"], + characters=model_characters, + scaler=scaler.state_dict() if c.mixed_precision else None, + ) # Diagnostic visualizations const_spec = postnet_output[0].data.cpu().numpy() - gt_spec = linear_input[0].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[0].data.cpu().numpy() + gt_spec = ( + linear_input[0].data.cpu().numpy() + if c.model in ["Tacotron", "TacotronGST"] + else mel_input[0].data.cpu().numpy() + ) align_img = alignments[0].data.cpu().numpy() figures = { @@ -301,7 +358,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, } if c.bidirectional_decoder or c.double_decoder_consistency: - figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False) + figures["alignment_backward"] = plot_alignment( + alignments_backward[0].data.cpu().numpy(), output_fig=False + ) tb_logger.tb_train_figures(global_step, figures) @@ -310,9 +369,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, train_audio = ap.inv_spectrogram(const_spec.T) else: train_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_train_audios(global_step, - {'TrainAudio': train_audio}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -339,31 +396,62 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): start_time = time.time() # format data - text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data) + ( + text_input, + text_lengths, + mel_input, + mel_lengths, + linear_input, + stop_targets, + speaker_ids, + speaker_embeddings, + _, + _, + ) = format_data(data) assert mel_input.shape[1] % model.decoder.r == 0 # forward pass model if c.bidirectional_decoder or c.double_decoder_consistency: - decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + ( + decoder_output, + postnet_output, + alignments, + stop_tokens, + decoder_backward_output, + alignments_backward, + ) = model( + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings + ) else: decoder_output, postnet_output, alignments, stop_tokens = model( - text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings) + text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings + ) decoder_backward_output = None alignments_backward = None # set the alignment lengths wrt reduction factor for guided attention if mel_lengths.max() % model.decoder.r != 0: - alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r + alignment_lengths = ( + mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r)) + ) // model.decoder.r else: - alignment_lengths = mel_lengths // model.decoder.r + alignment_lengths = mel_lengths // model.decoder.r # compute loss - loss_dict = criterion(postnet_output, decoder_output, mel_input, - linear_input, stop_tokens, stop_targets, - mel_lengths, decoder_backward_output, - alignments, alignment_lengths, alignments_backward, - text_lengths) + loss_dict = criterion( + postnet_output, + decoder_output, + mel_input, + linear_input, + stop_tokens, + stop_targets, + mel_lengths, + decoder_backward_output, + alignments, + alignment_lengths, + alignments_backward, + text_lengths, + ) # step time step_time = time.time() - start_time @@ -371,14 +459,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # compute alignment score align_error = 1 - alignment_diagonal_score(alignments) - loss_dict['align_error'] = align_error + loss_dict["align_error"] = align_error # aggregate losses from processes if num_gpus > 1: - loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus) - loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus) + loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus) + loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus) if c.stopnet: - loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) + loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) # detach loss values loss_dict_new = dict() @@ -392,7 +480,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value + update_train_values["avg_" + key] = value keep_avg.update_values(update_train_values) if c.print_eval: @@ -402,15 +490,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): # Diagnostic visualizations idx = np.random.randint(mel_input.shape[0]) const_spec = postnet_output[idx].data.cpu().numpy() - gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [ - "Tacotron", "TacotronGST" - ] else mel_input[idx].data.cpu().numpy() + gt_spec = ( + linear_input[idx].data.cpu().numpy() + if c.model in ["Tacotron", "TacotronGST"] + else mel_input[idx].data.cpu().numpy() + ) align_img = alignments[idx].data.cpu().numpy() eval_figures = { "prediction": plot_spectrogram(const_spec, ap, output_fig=False), "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False), - "alignment": plot_alignment(align_img, output_fig=False) + "alignment": plot_alignment(align_img, output_fig=False), } # Sample audio @@ -418,14 +508,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): eval_audio = ap.inv_spectrogram(const_spec.T) else: eval_audio = ap.inv_melspectrogram(const_spec.T) - tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"]) # Plot Validation Stats if c.bidirectional_decoder or c.double_decoder_consistency: align_b_img = alignments_backward[idx].data.cpu().numpy() - eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False) + eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) tb_logger.tb_eval_figures(global_step, eval_figures) @@ -436,7 +525,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): "Be a voice, not an echo.", "I'm sorry Dave. I'm afraid I can't do that.", "This cake is great. It's so delicious and moist.", - "Prior to November 22, 1963." + "Prior to November 22, 1963.", ] else: with open(c.test_sentences_file, "r") as f: @@ -447,13 +536,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): test_figures = {} print(" | > Synthesizing test sentences") speaker_id = 0 if c.use_speaker_embedding else None - speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None + speaker_embedding = ( + speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"] + if c.use_external_speaker_embedding_file and c.use_speaker_embedding + else None + ) style_wav = c.get("gst_style_input") if style_wav is None and c.use_gst: # inicialize GST with zero dict. style_wav = {} print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!") - for i in range(c.gst['gst_style_tokens']): + for i in range(c.gst["gst_style_tokens"]): style_wav[str(i)] = 0 style_wav = c.get("gst_style_input") for idx, test_sentence in enumerate(test_sentences): @@ -468,25 +561,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch): speaker_embedding=speaker_embedding, style_wav=style_wav, truncated=False, - enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument + enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument use_griffin_lim=True, - do_trim_silence=False) + do_trim_silence=False, + ) file_path = os.path.join(AUDIO_PATH, str(global_step)) os.makedirs(file_path, exist_ok=True) - file_path = os.path.join(file_path, - "TestSentence_{}.wav".format(idx)) + file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx)) ap.save_wav(wav, file_path) - test_audios['{}-audio'.format(idx)] = wav - test_figures['{}-prediction'.format(idx)] = plot_spectrogram( - postnet_output, ap, output_fig=False) - test_figures['{}-alignment'.format(idx)] = plot_alignment( - alignment, output_fig=False) - except: #pylint: disable=bare-except + test_audios["{}-audio".format(idx)] = wav + test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False) + test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False) + except: # pylint: disable=bare-except print(" !! Error creating Test Sentence -", idx) traceback.print_exc() - tb_logger.tb_test_audios(global_step, test_audios, - c.audio['sample_rate']) + tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"]) tb_logger.tb_test_figures(global_step, test_figures) return keep_avg.avg_values @@ -498,13 +588,12 @@ def main(args): # pylint: disable=redefined-outer-name ap = AudioProcessor(**c.audio) # setup custom characters if set in config file. - if 'characters' in c.keys(): + if "characters" in c.keys(): symbols, phonemes = make_symbols(**c.characters) # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) num_chars = len(phonemes) if c.use_phonemes else len(symbols) model_characters = phonemes if c.use_phonemes else symbols @@ -512,10 +601,10 @@ def main(args): # pylint: disable=redefined-outer-name meta_data_train, meta_data_eval = load_meta_data(c.datasets) # set the portion of the data used for training - if 'train_portion' in c.keys(): - meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)] - if 'eval_portion' in c.keys(): - meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)] + if "train_portion" in c.keys(): + meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)] + if "eval_portion" in c.keys(): + meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)] # parse speakers num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH) @@ -529,9 +618,7 @@ def main(args): # pylint: disable=redefined-outer-name params = set_weight_decay(model, c.wd) optimizer = RAdam(params, lr=c.lr, weight_decay=0) if c.stopnet and c.separate_stopnet: - optimizer_st = RAdam(model.decoder.stopnet.parameters(), - lr=c.lr, - weight_decay=0) + optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0) else: optimizer_st = None @@ -539,13 +626,13 @@ def main(args): # pylint: disable=redefined-outer-name criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4) if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Model...") - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) # optimizer restore print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint['optimizer']) + optimizer.load_state_dict(checkpoint["optimizer"]) if "scaler" in checkpoint and c.mixed_precision: print(" > Restoring AMP Scaler...") scaler.load_state_dict(checkpoint["scaler"]) @@ -554,17 +641,16 @@ def main(args): # pylint: disable=redefined-outer-name except (KeyError, RuntimeError): print(" > Partial model initialization...") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt')) # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt')) model.load_state_dict(model_dict) del model_dict for group in optimizer.param_groups: - group['lr'] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + group["lr"] = c.lr + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -577,9 +663,7 @@ def main(args): # pylint: disable=redefined-outer-name model = apply_gradient_allreduce(model) if c.noam_schedule: - scheduler = NoamLR(optimizer, - warmup_steps=c.warmup_steps, - last_epoch=args.restore_step - 1) + scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1) else: scheduler = None @@ -587,22 +671,17 @@ def main(args): # pylint: disable=redefined-outer-name print("\n > Model has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False # define data loaders - train_loader = setup_loader(ap, - model.decoder.r, - is_val=False, - verbose=True) + train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True) eval_loader = setup_loader(ap, model.decoder.r, is_val=True) global_step = args.restore_step @@ -617,28 +696,29 @@ def main(args): # pylint: disable=redefined-outer-name model.decoder_backward.set_r(r) train_loader.dataset.outputs_per_step = r eval_loader.dataset.outputs_per_step = r - train_loader = setup_loader(ap, - model.decoder.r, - is_val=False, - dataset=train_loader.dataset) - eval_loader = setup_loader(ap, - model.decoder.r, - is_val=True, - dataset=eval_loader.dataset) + train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset) + eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset) print("\n > Number of output frames:", model.decoder.r) # train one epoch - train_avg_loss_dict, global_step = train(train_loader, model, - criterion, optimizer, - optimizer_st, scheduler, ap, - global_step, epoch, scaler, - scaler_st) + train_avg_loss_dict, global_step = train( + train_loader, + model, + criterion, + optimizer, + optimizer_st, + scheduler, + ap, + global_step, + epoch, + scaler, + scaler_st, + ) # eval one epoch - eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, - global_step, epoch) + eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) - target_loss = train_avg_loss_dict['avg_postnet_loss'] + target_loss = train_avg_loss_dict["avg_postnet_loss"] if c.run_eval: - target_loss = eval_avg_loss_dict['avg_postnet_loss'] + target_loss = eval_avg_loss_dict["avg_postnet_loss"] best_loss = save_best_model( target_loss, best_loss, @@ -651,14 +731,13 @@ def main(args): # pylint: disable=redefined-outer-name model_characters, keep_all_best=keep_all_best, keep_after=keep_after, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='tts') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts") try: main(args) diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 9c0764fb..628a1f4c 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -12,8 +12,7 @@ import torch from torch.utils.data import DataLoader from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.radam import RAdam @@ -21,8 +20,7 @@ from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.gan_dataset import GANDataset from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss -from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator, - setup_generator) +from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator from TTS.vocoder.utils.io import save_best_model, save_checkpoint # DISTRIBUTED @@ -36,27 +34,30 @@ use_cuda, num_gpus = setup_torch_training_env(True, True) def setup_loader(ap, is_val=False, verbose=False): loader = None if not is_val or c.run_eval: - dataset = GANDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - is_training=not is_val, - return_segments=not is_val, - use_noise_augment=c.use_noise_augment, - use_cache=c.use_cache, - verbose=verbose) + dataset = GANDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad_short=c.pad_short, + conv_pad=c.conv_pad, + is_training=not is_val, + return_segments=not is_val, + use_noise_augment=c.use_noise_augment, + use_cache=c.use_cache, + verbose=verbose, + ) dataset.shuffle_mapping() sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=1 if is_val else c.batch_size, - shuffle=num_gpus == 0, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) + loader = DataLoader( + dataset, + batch_size=1 if is_val else c.batch_size, + shuffle=num_gpus == 0, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -83,16 +84,26 @@ def format_data(data): return co, x, None, None -def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, - scheduler_G, scheduler_D, ap, global_step, epoch): +def train( + model_G, + criterion_G, + optimizer_G, + model_D, + criterion_D, + optimizer_D, + scheduler_G, + scheduler_D, + ap, + global_step, + epoch, +): data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) model_G.train() model_D.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -148,16 +159,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, scores_fake = D_out_fake # compute losses - loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, - feats_real, y_hat_sub, y_G_sub) - loss_G = loss_G_dict['G_loss'] + loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub) + loss_G = loss_G_dict["G_loss"] # optimizer generator optimizer_G.zero_grad() loss_G.backward() if c.gen_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_G.parameters(), - c.gen_clip_grad) + torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad) optimizer_G.step() if scheduler_G is not None: scheduler_G.step() @@ -202,14 +211,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, # compute losses loss_D_dict = criterion_D(scores_fake, scores_real) - loss_D = loss_D_dict['D_loss'] + loss_D = loss_D_dict["D_loss"] # optimizer discriminator optimizer_D.zero_grad() loss_D.backward() if c.disc_clip_grad > 0: - torch.nn.utils.clip_grad_norm_(model_D.parameters(), - c.disc_clip_grad) + torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad) optimizer_D.step() if scheduler_D is not None: scheduler_D.step() @@ -224,36 +232,31 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, epoch_time += step_time # get current learning rates - current_lr_G = list(optimizer_G.param_groups)[0]['lr'] - current_lr_D = list(optimizer_D.param_groups)[0]['lr'] + current_lr_G = list(optimizer_G.param_groups)[0]["lr"] + current_lr_D = list(optimizer_D.param_groups)[0]["lr"] # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training stats if global_step % c.print_step == 0: log_dict = { - 'step_time': [step_time, 2], - 'loader_time': [loader_time, 4], + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], "current_lr_G": current_lr_G, - "current_lr_D": current_lr_D + "current_lr_D": current_lr_D, } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # plot step stats if global_step % 10 == 0: - iter_stats = { - "lr_G": current_lr_G, - "lr_D": current_lr_D, - "step_time": step_time - } + iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -261,27 +264,26 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D, if global_step % c.save_step == 0: if c.checkpoint: # save model - save_checkpoint(model_G, - optimizer_G, - scheduler_G, - model_D, - optimizer_D, - scheduler_D, - global_step, - epoch, - OUT_PATH, - model_losses=loss_dict) + save_checkpoint( + model_G, + optimizer_G, + scheduler_G, + model_D, + optimizer_D, + scheduler_D, + global_step, + epoch, + OUT_PATH, + model_losses=loss_dict, + ) # compute spectrograms - figures = plot_results(y_hat_vis, y_G, ap, global_step, - 'train') + figures = plot_results(y_hat_vis, y_G, ap, global_step, "train") tb_logger.tb_train_figures(global_step, figures) # Sample audio sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_train_audios(global_step, - {'train/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -356,8 +358,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) feats_fake, feats_real = None, None # compute losses - loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, - feats_real, y_hat_sub, y_G_sub) + loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub) loss_dict = dict() for key, value in loss_G_dict.items(): @@ -413,9 +414,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) # update avg stats update_eval_values = dict() for key, value in loss_dict.items(): - update_eval_values['avg_' + key] = value - update_eval_values['avg_loader_time'] = loader_time - update_eval_values['avg_step_time'] = step_time + update_eval_values["avg_" + key] = value + update_eval_values["avg_loader_time"] = loader_time + update_eval_values["avg_step_time"] = step_time keep_avg.update_values(update_eval_values) # print eval stats @@ -424,13 +425,12 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch) if args.rank == 0: # compute spectrograms - figures = plot_results(y_hat, y_G, ap, global_step, 'eval') + figures = plot_results(y_hat, y_G, ap, global_step, "eval") tb_logger.tb_eval_figures(global_step, figures) # Sample audio sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) @@ -446,8 +446,7 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data( - c.data_path, c.feature_path, c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) @@ -456,8 +455,7 @@ def main(args): # pylint: disable=redefined-outer-name # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # setup models model_gen = setup_generator(c) @@ -465,21 +463,17 @@ def main(args): # pylint: disable=redefined-outer-name # setup optimizers optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0) - optimizer_disc = RAdam(model_disc.parameters(), - lr=c.lr_disc, - weight_decay=0) + optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0) # schedulers scheduler_gen = None scheduler_disc = None - if 'lr_scheduler_gen' in c: + if "lr_scheduler_gen" in c: scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen) - scheduler_gen = scheduler_gen( - optimizer_gen, **c.lr_scheduler_gen_params) - if 'lr_scheduler_disc' in c: + scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params) + if "lr_scheduler_disc" in c: scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc) - scheduler_disc = scheduler_disc( - optimizer_disc, **c.lr_scheduler_disc_params) + scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params) # setup criterion criterion_gen = GeneratorLoss(c) @@ -487,47 +481,46 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Generator Model...") - model_gen.load_state_dict(checkpoint['model']) + model_gen.load_state_dict(checkpoint["model"]) print(" > Restoring Generator Optimizer...") - optimizer_gen.load_state_dict(checkpoint['optimizer']) + optimizer_gen.load_state_dict(checkpoint["optimizer"]) print(" > Restoring Discriminator Model...") - model_disc.load_state_dict(checkpoint['model_disc']) + model_disc.load_state_dict(checkpoint["model_disc"]) print(" > Restoring Discriminator Optimizer...") - optimizer_disc.load_state_dict(checkpoint['optimizer_disc']) - if 'scheduler' in checkpoint and scheduler_gen is not None: + optimizer_disc.load_state_dict(checkpoint["optimizer_disc"]) + if "scheduler" in checkpoint and scheduler_gen is not None: print(" > Restoring Generator LR Scheduler...") - scheduler_gen.load_state_dict(checkpoint['scheduler']) + scheduler_gen.load_state_dict(checkpoint["scheduler"]) # NOTE: Not sure if necessary scheduler_gen.optimizer = optimizer_gen - if 'scheduler_disc' in checkpoint and scheduler_disc is not None: + if "scheduler_disc" in checkpoint and scheduler_disc is not None: print(" > Restoring Discriminator LR Scheduler...") - scheduler_disc.load_state_dict(checkpoint['scheduler_disc']) + scheduler_disc.load_state_dict(checkpoint["scheduler_disc"]) scheduler_disc.optimizer = optimizer_disc except RuntimeError: # restore only matching layers. print(" > Partial model initialization...") model_dict = model_gen.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model_gen.load_state_dict(model_dict) model_dict = model_disc.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c) + model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c) model_disc.load_state_dict(model_dict) del model_dict # reset lr if not countinuining training. for group in optimizer_gen.param_groups: - group['lr'] = c.lr_gen + group["lr"] = c.lr_gen for group in optimizer_disc.param_groups: - group['lr'] = c.lr_disc + group["lr"] = c.lr_disc - print(f" > Model restored from step {checkpoint['step']:d}", - flush=True) - args.restore_step = checkpoint['step'] + print(f" > Model restored from step {checkpoint['step']:d}", flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -548,50 +541,55 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Discriminator has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with best loss of {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model_gen, criterion_gen, optimizer_gen, - model_disc, criterion_disc, optimizer_disc, - scheduler_gen, scheduler_disc, ap, global_step, - epoch) - eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, - criterion_disc, ap, - global_step, epoch) + _, global_step = train( + model_gen, + criterion_gen, + optimizer_gen, + model_disc, + criterion_disc, + optimizer_disc, + scheduler_gen, + scheduler_disc, + ap, + global_step, + epoch, + ) + eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict[c.target_loss] - best_loss = save_best_model(target_loss, - best_loss, - model_gen, - optimizer_gen, - scheduler_gen, - model_disc, - optimizer_disc, - scheduler_disc, - global_step, - epoch, - OUT_PATH, - keep_all_best=keep_all_best, - keep_after=keep_after, - model_losses=eval_avg_loss_dict, - ) + best_loss = save_best_model( + target_loss, + best_loss, + model_gen, + optimizer_gen, + scheduler_gen, + model_disc, + optimizer_disc, + scheduler_disc, + global_step, + epoch, + OUT_PATH, + keep_all_best=keep_all_best, + keep_after=keep_after, + model_losses=eval_avg_loss_dict, + ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/train_vocoder_wavegrad.py b/TTS/bin/train_vocoder_wavegrad.py index 68d76598..e5ade7b8 100644 --- a/TTS/bin/train_vocoder_wavegrad.py +++ b/TTS/bin/train_vocoder_wavegrad.py @@ -8,6 +8,7 @@ import traceback import numpy as np import torch + # DISTRIBUTED from torch.nn.parallel import DistributedDataParallel as DDP_th from torch.optim import Adam @@ -16,8 +17,7 @@ from torch.utils.data.distributed import DistributedSampler from TTS.utils.arguments import parse_arguments, process_args from TTS.utils.audio import AudioProcessor from TTS.utils.distribute import init_distributed -from TTS.utils.generic_utils import (KeepAverage, count_parameters, - remove_experiment_folder, set_init_dict) +from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict from TTS.utils.training import setup_torch_training_env from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset @@ -31,27 +31,29 @@ def setup_loader(ap, is_val=False, verbose=False): if is_val and not c.run_eval: loader = None else: - dataset = WaveGradDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad_short=c.pad_short, - conv_pad=c.conv_pad, - is_training=not is_val, - return_segments=True, - use_noise_augment=False, - use_cache=c.use_cache, - verbose=verbose) + dataset = WaveGradDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad_short=c.pad_short, + conv_pad=c.conv_pad, + is_training=not is_val, + return_segments=True, + use_noise_augment=False, + use_cache=c.use_cache, + verbose=verbose, + ) sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - batch_size=c.batch_size, - shuffle=num_gpus <= 1, - drop_last=False, - sampler=sampler, - num_workers=c.num_val_loader_workers - if is_val else c.num_loader_workers, - pin_memory=False) - + loader = DataLoader( + dataset, + batch_size=c.batch_size, + shuffle=num_gpus <= 1, + drop_last=False, + sampler=sampler, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=False, + ) return loader @@ -77,24 +79,21 @@ def format_test_data(data): return m, x -def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, - epoch): +def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch): data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0)) model.train() epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int( - len(data_loader.dataset) / (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() c_logger.print_train_start() # setup noise schedule - noise_schedule = c['train_noise_schedule'] - betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'], - noise_schedule['num_steps']) - if hasattr(model, 'module'): + noise_schedule = c["train_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + if hasattr(model, "module"): model.module.compute_noise_level(betas) else: model.compute_noise_level(betas) @@ -109,7 +108,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, with torch.cuda.amp.autocast(enabled=c.mixed_precision): # compute noisy input - if hasattr(model, 'module'): + if hasattr(model, "module"): noise, x_noisy, noise_scale = model.module.compute_y_n(x) else: noise, x_noisy, noise_scale = model.compute_y_n(x) @@ -119,11 +118,11 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, # compute losses loss = criterion(noise, noise_hat) - loss_wavegrad_dict = {'wavegrad_loss': loss} + loss_wavegrad_dict = {"wavegrad_loss": loss} # check nan loss if torch.isnan(loss).any(): - raise RuntimeError(f'Detected NaN loss at step {global_step}.') + raise RuntimeError(f"Detected NaN loss at step {global_step}.") optimizer.zero_grad() @@ -131,14 +130,12 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, if c.mixed_precision: scaler.scale(loss).backward() scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.clip_grad) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad) scaler.step(optimizer) scaler.update() else: loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), - c.clip_grad) + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad) optimizer.step() # schedule update @@ -158,35 +155,30 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch_time += step_time # get current learning rates - current_lr = list(optimizer.param_groups)[0]['lr'] + current_lr = list(optimizer.param_groups)[0]["lr"] # update avg stats update_train_values = dict() for key, value in loss_dict.items(): - update_train_values['avg_' + key] = value - update_train_values['avg_loader_time'] = loader_time - update_train_values['avg_step_time'] = step_time + update_train_values["avg_" + key] = value + update_train_values["avg_loader_time"] = loader_time + update_train_values["avg_step_time"] = step_time keep_avg.update_values(update_train_values) # print training stats if global_step % c.print_step == 0: log_dict = { - 'step_time': [step_time, 2], - 'loader_time': [loader_time, 4], + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], "current_lr": current_lr, - "grad_norm": grad_norm.item() + "grad_norm": grad_norm.item(), } - c_logger.print_train_step(batch_n_iter, num_iter, global_step, - log_dict, loss_dict, keep_avg.avg_values) + c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values) if args.rank == 0: # plot step stats if global_step % 10 == 0: - iter_stats = { - "lr": current_lr, - "grad_norm": grad_norm.item(), - "step_time": step_time - } + iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time} iter_stats.update(loss_dict) tb_logger.tb_train_iter_stats(global_step, iter_stats) @@ -205,7 +197,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch, OUT_PATH, model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) end_time = time.time() @@ -242,19 +234,17 @@ def evaluate(model, criterion, ap, global_step, epoch): global_step += 1 # compute noisy input - if hasattr(model, 'module'): + if hasattr(model, "module"): noise, x_noisy, noise_scale = model.module.compute_y_n(x) else: noise, x_noisy, noise_scale = model.compute_y_n(x) - # forward pass noise_hat = model(x_noisy, m, noise_scale) # compute losses loss = criterion(noise, noise_hat) - loss_wavegrad_dict = {'wavegrad_loss': loss} - + loss_wavegrad_dict = {"wavegrad_loss": loss} loss_dict = dict() for key, value in loss_wavegrad_dict.items(): @@ -269,9 +259,9 @@ def evaluate(model, criterion, ap, global_step, epoch): # update avg stats update_eval_values = dict() for key, value in loss_dict.items(): - update_eval_values['avg_' + key] = value - update_eval_values['avg_loader_time'] = loader_time - update_eval_values['avg_step_time'] = step_time + update_eval_values["avg_" + key] = value + update_eval_values["avg_loader_time"] = loader_time + update_eval_values["avg_step_time"] = step_time keep_avg.update_values(update_eval_values) # print eval stats @@ -284,11 +274,9 @@ def evaluate(model, criterion, ap, global_step, epoch): m, x = format_test_data(samples[0]) # setup noise schedule and inference - noise_schedule = c['test_noise_schedule'] - betas = np.linspace(noise_schedule['min_val'], - noise_schedule['max_val'], - noise_schedule['num_steps']) - if hasattr(model, 'module'): + noise_schedule = c["test_noise_schedule"] + betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"]) + if hasattr(model, "module"): model.module.compute_noise_level(betas) # compute voice x_pred = model.module.inference(m) @@ -298,13 +286,12 @@ def evaluate(model, criterion, ap, global_step, epoch): x_pred = model.inference(m) # compute spectrograms - figures = plot_results(x_pred, x, ap, global_step, 'eval') + figures = plot_results(x_pred, x, ap, global_step, "eval") tb_logger.tb_eval_figures(global_step, figures) # Sample audio sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy() - tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice}, - c.audio["sample_rate"]) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"]) tb_logger.tb_eval_stats(global_step, keep_avg.avg_values) data_loader.dataset.return_segments = True @@ -318,8 +305,7 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, - c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) @@ -328,8 +314,7 @@ def main(args): # pylint: disable=redefined-outer-name # DISTRUBUTED if num_gpus > 1: - init_distributed(args.rank, num_gpus, args.group_id, - c.distributed["backend"], c.distributed["url"]) + init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"]) # setup models model = setup_generator(c) @@ -342,7 +327,7 @@ def main(args): # pylint: disable=redefined-outer-name # schedulers scheduler = None - if 'lr_scheduler' in c: + if "lr_scheduler" in c: scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler) scheduler = scheduler(optimizer, **c.lr_scheduler_params) @@ -355,15 +340,15 @@ def main(args): # pylint: disable=redefined-outer-name if args.restore_path: print(f" > Restoring from {os.path.basename(args.restore_path)}...") - checkpoint = torch.load(args.restore_path, map_location='cpu') + checkpoint = torch.load(args.restore_path, map_location="cpu") try: print(" > Restoring Model...") - model.load_state_dict(checkpoint['model']) + model.load_state_dict(checkpoint["model"]) print(" > Restoring Optimizer...") - optimizer.load_state_dict(checkpoint['optimizer']) - if 'scheduler' in checkpoint: + optimizer.load_state_dict(checkpoint["optimizer"]) + if "scheduler" in checkpoint: print(" > Restoring LR Scheduler...") - scheduler.load_state_dict(checkpoint['scheduler']) + scheduler.load_state_dict(checkpoint["scheduler"]) # NOTE: Not sure if necessary scheduler.optimizer = optimizer if "scaler" in checkpoint and c.mixed_precision: @@ -373,17 +358,16 @@ def main(args): # pylint: disable=redefined-outer-name # retore only matching layers. print(" > Partial model initialization...") model_dict = model.state_dict() - model_dict = set_init_dict(model_dict, checkpoint['model'], c) + model_dict = set_init_dict(model_dict, checkpoint["model"], c) model.load_state_dict(model_dict) del model_dict # reset lr if not countinuining training. for group in optimizer.param_groups: - group['lr'] = c.lr + group["lr"] = c.lr - print(" > Model restored from step %d" % checkpoint['step'], - flush=True) - args.restore_step = checkpoint['step'] + print(" > Model restored from step %d" % checkpoint["step"], flush=True) + args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -395,22 +379,19 @@ def main(args): # pylint: disable=redefined-outer-name print(" > WaveGrad has {} parameters".format(num_params), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model, criterion, optimizer, scheduler, scaler, - ap, global_step, epoch) + _, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch) eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict[c.target_loss] @@ -429,14 +410,13 @@ def main(args): # pylint: disable=redefined-outer-name keep_all_best=keep_all_best, keep_after=keep_after, model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/train_vocoder_wavernn.py b/TTS/bin/train_vocoder_wavernn.py index 6b75405a..25129883 100644 --- a/TTS/bin/train_vocoder_wavernn.py +++ b/TTS/bin/train_vocoder_wavernn.py @@ -24,10 +24,7 @@ from TTS.utils.generic_utils import ( set_init_dict, ) from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset -from TTS.vocoder.datasets.preprocess import ( - load_wav_data, - load_wav_feat_data -) +from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss from TTS.vocoder.utils.generic_utils import setup_generator from TTS.vocoder.utils.io import save_best_model, save_checkpoint @@ -40,26 +37,26 @@ def setup_loader(ap, is_val=False, verbose=False): if is_val and not c.run_eval: loader = None else: - dataset = WaveRNNDataset(ap=ap, - items=eval_data if is_val else train_data, - seq_len=c.seq_len, - hop_len=ap.hop_length, - pad=c.padding, - mode=c.mode, - mulaw=c.mulaw, - is_training=not is_val, - verbose=verbose, - ) + dataset = WaveRNNDataset( + ap=ap, + items=eval_data if is_val else train_data, + seq_len=c.seq_len, + hop_len=ap.hop_length, + pad=c.padding, + mode=c.mode, + mulaw=c.mulaw, + is_training=not is_val, + verbose=verbose, + ) # sampler = DistributedSampler(dataset) if num_gpus > 1 else None - loader = DataLoader(dataset, - shuffle=True, - collate_fn=dataset.collate, - batch_size=c.batch_size, - num_workers=c.num_val_loader_workers - if is_val - else c.num_loader_workers, - pin_memory=True, - ) + loader = DataLoader( + dataset, + shuffle=True, + collate_fn=dataset.collate, + batch_size=c.batch_size, + num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers, + pin_memory=True, + ) return loader @@ -85,8 +82,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch epoch_time = 0 keep_avg = KeepAverage() if use_cuda: - batch_n_iter = int(len(data_loader.dataset) / - (c.batch_size * num_gpus)) + batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus)) else: batch_n_iter = int(len(data_loader.dataset) / c.batch_size) end_time = time.time() @@ -114,8 +110,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch scaler.scale(loss).backward() scaler.unscale_(optimizer) if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) scaler.step(optimizer) scaler.update() else: @@ -132,8 +127,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch raise RuntimeError(" [!] None loss. Exiting ...") loss.backward() if c.grad_clip > 0: - torch.nn.utils.clip_grad_norm_( - model.parameters(), c.grad_clip) + torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip) optimizer.step() if scheduler is not None: @@ -156,17 +150,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch # print training stats if global_step % c.print_step == 0: - log_dict = {"step_time": [step_time, 2], - "loader_time": [loader_time, 4], - "current_lr": cur_lr, - } - c_logger.print_train_step(batch_n_iter, - num_iter, - global_step, - log_dict, - loss_dict, - keep_avg.avg_values, - ) + log_dict = { + "step_time": [step_time, 2], + "loader_time": [loader_time, 4], + "current_lr": cur_lr, + } + c_logger.print_train_step( + batch_n_iter, + num_iter, + global_step, + log_dict, + loss_dict, + keep_avg.avg_values, + ) # plot step stats if global_step % 10 == 0: @@ -189,36 +185,36 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch epoch, OUT_PATH, model_losses=loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) # synthesize a full voice rand_idx = random.randrange(0, len(train_data)) - wav_path = train_data[rand_idx] if not isinstance( - train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] + wav_path = ( + train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0] + ) wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) ground_mel = torch.FloatTensor(ground_mel) if use_cuda: ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference(ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) + sample_wav = model.inference( + ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # compute spectrograms - figures = {"train/ground_truth": plot_spectrogram(ground_mel.T), - "train/prediction": plot_spectrogram(predict_mel.T) - } + figures = { + "train/ground_truth": plot_spectrogram(ground_mel.T), + "train/prediction": plot_spectrogram(predict_mel.T), + } tb_logger.tb_train_figures(global_step, figures) # Sample audio - tb_logger.tb_train_audios( - global_step, { - "train/audio": sample_wav}, c.audio["sample_rate"] - ) + tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"]) end_time = time.time() # print epoch stats @@ -277,36 +273,32 @@ def evaluate(model, criterion, ap, global_step, epoch): # print eval stats if c.print_eval: - c_logger.print_eval_step( - num_iter, loss_dict, keep_avg.avg_values) + c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values) if epoch % c.test_every_epochs == 0 and epoch != 0: # synthesize a full voice rand_idx = random.randrange(0, len(eval_data)) - wav_path = eval_data[rand_idx] if not isinstance( - eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] + wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0] wav = ap.load_wav(wav_path) ground_mel = ap.melspectrogram(wav) ground_mel = torch.FloatTensor(ground_mel) if use_cuda: ground_mel = ground_mel.cuda(non_blocking=True) - sample_wav = model.inference(ground_mel, - c.batched, - c.target_samples, - c.overlap_samples, - ) + sample_wav = model.inference( + ground_mel, + c.batched, + c.target_samples, + c.overlap_samples, + ) predict_mel = ap.melspectrogram(sample_wav) # Sample audio - tb_logger.tb_eval_audios( - global_step, { - "eval/audio": sample_wav}, c.audio["sample_rate"] - ) + tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"]) # compute spectrograms figures = { "eval/ground_truth": plot_spectrogram(ground_mel.T), - "eval/prediction": plot_spectrogram(predict_mel.T) + "eval/prediction": plot_spectrogram(predict_mel.T), } tb_logger.tb_eval_figures(global_step, figures) @@ -347,11 +339,9 @@ def main(args): # pylint: disable=redefined-outer-name print(f" > Loading wavs from: {c.data_path}") if c.feature_path is not None: print(f" > Loading features from: {c.feature_path}") - eval_data, train_data = load_wav_feat_data( - c.data_path, c.feature_path, c.eval_split_size) + eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size) else: - eval_data, train_data = load_wav_data( - c.data_path, c.eval_split_size) + eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size) # setup model model_wavernn = setup_generator(c) @@ -404,8 +394,7 @@ def main(args): # pylint: disable=redefined-outer-name model_dict = set_init_dict(model_dict, checkpoint["model"], c) model_wavernn.load_state_dict(model_dict) - print(" > Model restored from step %d" % - checkpoint["step"], flush=True) + print(" > Model restored from step %d" % checkpoint["step"], flush=True) args.restore_step = checkpoint["step"] else: args.restore_step = 0 @@ -418,24 +407,20 @@ def main(args): # pylint: disable=redefined-outer-name print(" > Model has {} parameters".format(num_parameters), flush=True) if args.restore_step == 0 or not args.best_path: - best_loss = float('inf') + best_loss = float("inf") print(" > Starting with inf best loss.") else: - print(" > Restoring best loss from " - f"{os.path.basename(args.best_path)} ...") - best_loss = torch.load(args.best_path, - map_location='cpu')['model_loss'] + print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...") + best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"] print(f" > Starting with loaded last best loss {best_loss}.") - keep_all_best = c.get('keep_all_best', False) - keep_after = c.get('keep_after', 10000) # void if keep_all_best False + keep_all_best = c.get("keep_all_best", False) + keep_after = c.get("keep_after", 10000) # void if keep_all_best False global_step = args.restore_step for epoch in range(0, c.epochs): c_logger.print_epoch_start(epoch, c.epochs) - _, global_step = train(model_wavernn, optimizer, - criterion, scheduler, scaler, ap, global_step, epoch) - eval_avg_loss_dict = evaluate( - model_wavernn, criterion, ap, global_step, epoch) + _, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch) + eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch) c_logger.print_epoch_end(epoch, eval_avg_loss_dict) target_loss = eval_avg_loss_dict["avg_model_loss"] best_loss = save_best_model( @@ -453,14 +438,13 @@ def main(args): # pylint: disable=redefined-outer-name keep_all_best=keep_all_best, keep_after=keep_after, model_losses=eval_avg_loss_dict, - scaler=scaler.state_dict() if c.mixed_precision else None + scaler=scaler.state_dict() if c.mixed_precision else None, ) if __name__ == "__main__": args = parse_arguments(sys.argv) - c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args( - args, model_class='vocoder') + c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder") try: main(args) diff --git a/TTS/bin/tune_wavegrad.py b/TTS/bin/tune_wavegrad.py index 436a2764..d0f64214 100644 --- a/TTS/bin/tune_wavegrad.py +++ b/TTS/bin/tune_wavegrad.py @@ -13,14 +13,21 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset from TTS.vocoder.utils.generic_utils import setup_generator parser = argparse.ArgumentParser() -parser.add_argument('--model_path', type=str, help='Path to model checkpoint.') -parser.add_argument('--config_path', type=str, help='Path to model config file.') -parser.add_argument('--data_path', type=str, help='Path to data directory.') -parser.add_argument('--output_path', type=str, help='path for output file including file name and extension.') -parser.add_argument('--num_iter', type=int, help='Number of model inference iterations that you like to optimize noise schedule for.') -parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.') -parser.add_argument('--num_samples', type=int, default=1, help='Number of datasamples used for inference.') -parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.') +parser.add_argument("--model_path", type=str, help="Path to model checkpoint.") +parser.add_argument("--config_path", type=str, help="Path to model config file.") +parser.add_argument("--data_path", type=str, help="Path to data directory.") +parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.") +parser.add_argument( + "--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for." +) +parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.") +parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.") +parser.add_argument( + "--search_depth", + type=int, + default=3, + help="Search granularity. Increasing this increases the run-time exponentially.", +) # load config args = parser.parse_args() @@ -31,18 +38,20 @@ ap = AudioProcessor(**config.audio) # load dataset _, train_data = load_wav_data(args.data_path, 0) -train_data = train_data[:args.num_samples] -dataset = WaveGradDataset(ap=ap, - items=train_data, - seq_len=-1, - hop_len=ap.hop_length, - pad_short=config.pad_short, - conv_pad=config.conv_pad, - is_training=True, - return_segments=False, - use_noise_augment=False, - use_cache=False, - verbose=True) +train_data = train_data[: args.num_samples] +dataset = WaveGradDataset( + ap=ap, + items=train_data, + seq_len=-1, + hop_len=ap.hop_length, + pad_short=config.pad_short, + conv_pad=config.conv_pad, + is_training=True, + return_segments=False, + use_noise_augment=False, + use_cache=False, + verbose=True, +) loader = DataLoader( dataset, batch_size=1, @@ -50,7 +59,8 @@ loader = DataLoader( collate_fn=dataset.collate_full_clips, drop_last=False, num_workers=config.num_loader_workers, - pin_memory=False) + pin_memory=False, +) # setup the model model = setup_generator(config) @@ -61,9 +71,9 @@ if args.use_cuda: base_values = sorted(10 * np.random.uniform(size=args.search_depth)) print(base_values) exponents = 10 ** np.linspace(-6, -1, num=args.num_iter) -best_error = float('inf') +best_error = float("inf") best_schedule = None -total_search_iter = len(base_values)**args.num_iter +total_search_iter = len(base_values) ** args.num_iter for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter): beta = exponents * base model.compute_noise_level(beta) @@ -84,6 +94,6 @@ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=tot mse = torch.sum((mel - mel_hat) ** 2).mean() if mse.item() < best_error: best_error = mse.item() - best_schedule = {'beta': beta} + best_schedule = {"beta": beta} print(f" > Found a better schedule. - MSE: {mse.item()}") np.save(args.output_path, best_schedule) diff --git a/TTS/server/server.py b/TTS/server/server.py index 05960f88..106c282c 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -13,23 +13,40 @@ from TTS.utils.io import load_config def create_argparser(): def convert_boolean(x): - return x.lower() in ['true', '1', 'yes'] + return x.lower() in ["true", "1", "yes"] parser = argparse.ArgumentParser() - parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.') - parser.add_argument('--model_name', type=str, default="tts_models/en/ljspeech/speedy-speech-wn", help='name of one of the released tts models.') - parser.add_argument('--vocoder_name', type=str, default=None, help='name of one of the released vocoder models.') - parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file') - parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file') - parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model') - parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.') - parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.') - parser.add_argument('--port', type=int, default=5002, help='port to listen on.') - parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.') - parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.') - parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.') + parser.add_argument( + "--list_models", + type=convert_boolean, + nargs="?", + const=True, + default=False, + help="list available pre-trained tts and vocoder models.", + ) + parser.add_argument( + "--model_name", + type=str, + default="tts_models/en/ljspeech/speedy-speech-wn", + help="name of one of the released tts models.", + ) + parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.") + parser.add_argument("--tts_checkpoint", type=str, help="path to custom tts checkpoint file") + parser.add_argument("--tts_config", type=str, help="path to custom tts config.json file") + parser.add_argument( + "--tts_speakers", + type=str, + help="path to JSON file containing speaker ids, if speaker ids are used in the model", + ) + parser.add_argument("--vocoder_config", type=str, default=None, help="path to vocoder config file.") + parser.add_argument("--vocoder_checkpoint", type=str, default=None, help="path to vocoder checkpoint file.") + parser.add_argument("--port", type=int, default=5002, help="port to listen on.") + parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.") + parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.") + parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.") return parser + # parse the args args = create_argparser().parse_args() @@ -43,7 +60,7 @@ if args.list_models: # update in-use models to the specified released models. if args.model_name is not None: tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name) - args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name + args.vocoder_name = tts_json_dict["default_vocoder"] if args.vocoder_name is None else args.vocoder_name if args.vocoder_name is not None: vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name) @@ -59,16 +76,19 @@ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file): if not args.vocoder_config and os.path.isfile(vocoder_config_file): args.vocoder_config = vocoder_config_file -synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda) +synthesizer = Synthesizer( + args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda +) app = Flask(__name__) -@app.route('/') +@app.route("/") def index(): - return render_template('index.html', show_details=args.show_details) + return render_template("index.html", show_details=args.show_details) -@app.route('/details') + +@app.route("/details") def details(): model_config = load_config(args.tts_config) if args.vocoder_config is not None and os.path.isfile(args.vocoder_config): @@ -76,26 +96,28 @@ def details(): else: vocoder_config = None - return render_template('details.html', - show_details=args.show_details - , model_config=model_config - , vocoder_config=vocoder_config - , args=args.__dict__ - ) + return render_template( + "details.html", + show_details=args.show_details, + model_config=model_config, + vocoder_config=vocoder_config, + args=args.__dict__, + ) -@app.route('/api/tts', methods=['GET']) + +@app.route("/api/tts", methods=["GET"]) def tts(): - text = request.args.get('text') + text = request.args.get("text") print(" > Model input: {}".format(text)) wavs = synthesizer.tts(text) out = io.BytesIO() synthesizer.save_wav(wavs, out) - return send_file(out, mimetype='audio/wav') + return send_file(out, mimetype="audio/wav") def main(): - app.run(debug=args.debug, host='0.0.0.0', port=args.port) + app.run(debug=args.debug, host="0.0.0.0", port=args.port) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/TTS/speaker_encoder/dataset.py b/TTS/speaker_encoder/dataset.py index 748f5136..38d8b5f9 100644 --- a/TTS/speaker_encoder/dataset.py +++ b/TTS/speaker_encoder/dataset.py @@ -7,9 +7,19 @@ from torch.utils.data import Dataset class MyDataset(Dataset): - def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64, - storage_size=1, sample_from_storage_p=0.5, additive_noise=0, - num_utter_per_speaker=10, skip_speakers=False, verbose=False): + def __init__( + self, + ap, + meta_data, + voice_len=1.6, + num_speakers_in_batch=64, + storage_size=1, + sample_from_storage_p=0.5, + additive_noise=0, + num_utter_per_speaker=10, + skip_speakers=False, + verbose=False, + ): """ Args: ap (TTS.tts.utils.AudioProcessor): audio processor object. @@ -28,7 +38,7 @@ class MyDataset(Dataset): self.ap = ap self.verbose = verbose self.__parse_items() - self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch) + self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch) self.sample_from_storage_p = float(sample_from_storage_p) self.additive_noise = float(additive_noise) if self.verbose: @@ -69,11 +79,14 @@ class MyDataset(Dataset): if speaker_ in self.speaker_to_utters.keys(): self.speaker_to_utters[speaker_].append(path_) else: - self.speaker_to_utters[speaker_] = [path_, ] + self.speaker_to_utters[speaker_] = [ + path_, + ] if self.skip_speakers: - self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items() if - len(v) >= self.num_utter_per_speaker} + self.speaker_to_utters = { + k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker + } self.speakers = [k for (k, v) in self.speaker_to_utters.items()] @@ -100,13 +113,9 @@ class MyDataset(Dataset): def __sample_speaker(self): speaker = random.sample(self.speakers, 1)[0] if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]): - utters = random.choices( - self.speaker_to_utters[speaker], k=self.num_utter_per_speaker - ) + utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker) else: - utters = random.sample( - self.speaker_to_utters[speaker], self.num_utter_per_speaker - ) + utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker) return speaker, utters def __sample_speaker_utterances(self, speaker): @@ -160,7 +169,9 @@ class MyDataset(Dataset): # get a random subset of each of the wavs and convert to MFCC. offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_] - mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))] + mels_ = [ + self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_)) + ] feats_ = [torch.FloatTensor(mel) for mel in mels_] labels.append(labels_) diff --git a/TTS/speaker_encoder/losses.py b/TTS/speaker_encoder/losses.py index fc085674..69264ab4 100644 --- a/TTS/speaker_encoder/losses.py +++ b/TTS/speaker_encoder/losses.py @@ -16,14 +16,14 @@ class GE2ELoss(nn.Module): - init_w (float): defines the initial value of w in Equation (5) of [1] - init_b (float): definies the initial value of b in Equation (5) of [1] """ - super(GE2ELoss, self).__init__() + super().__init__() # pylint: disable=E1102 self.w = nn.Parameter(torch.tensor(init_w)) # pylint: disable=E1102 self.b = nn.Parameter(torch.tensor(init_b)) self.loss_method = loss_method - print(' > Initialised Generalized End-to-End loss') + print(" > Initialised Generalized End-to-End loss") assert self.loss_method in ["softmax", "contrast"] @@ -55,9 +55,7 @@ class GE2ELoss(nn.Module): for spkr_idx, speaker in enumerate(dvecs): cs_row = [] for utt_idx, utterance in enumerate(speaker): - new_centroids = self.calc_new_centroids( - dvecs, centroids, spkr_idx, utt_idx - ) + new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx) # vector based cosine similarity for speed cs_row.append( torch.clamp( @@ -99,14 +97,8 @@ class GE2ELoss(nn.Module): L_row = [] for i in range(M): centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i]) - excl_centroids_sigmoids = torch.cat( - (centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]) - ) - L_row.append( - 1.0 - - torch.sigmoid(cos_sim_matrix[j, i, j]) - + torch.max(excl_centroids_sigmoids) - ) + excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])) + L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids)) L_row = torch.stack(L_row) L.append(L_row) return torch.stack(L) @@ -122,6 +114,7 @@ class GE2ELoss(nn.Module): L = self.embed_loss(dvecs, cos_sim_matrix) return L.mean() + # adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py class AngleProtoLoss(nn.Module): """ @@ -134,15 +127,16 @@ class AngleProtoLoss(nn.Module): - init_w (float): defines the initial value of w - init_b (float): definies the initial value of b """ + def __init__(self, init_w=10.0, init_b=-5.0): - super(AngleProtoLoss, self).__init__() + super().__init__() # pylint: disable=E1102 self.w = nn.Parameter(torch.tensor(init_w)) # pylint: disable=E1102 self.b = nn.Parameter(torch.tensor(init_b)) self.criterion = torch.nn.CrossEntropyLoss() - print(' > Initialised Angular Prototypical loss') + print(" > Initialised Angular Prototypical loss") def forward(self, x): """ @@ -152,7 +146,10 @@ class AngleProtoLoss(nn.Module): out_positive = x[:, 0, :] num_speakers = out_anchor.size()[0] - cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2)) + cos_sim_matrix = F.cosine_similarity( + out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), + out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2), + ) torch.clamp(self.w, 1e-6) cos_sim_matrix = cos_sim_matrix * self.w + self.b label = torch.arange(num_speakers).to(cos_sim_matrix.device) diff --git a/TTS/speaker_encoder/model.py b/TTS/speaker_encoder/model.py index 322ee42f..7a3dc09c 100644 --- a/TTS/speaker_encoder/model.py +++ b/TTS/speaker_encoder/model.py @@ -16,19 +16,19 @@ class LSTMWithProjection(nn.Module): o, (_, _) = self.lstm(x) return self.linear(o) + class LSTMWithoutProjection(nn.Module): def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers): super().__init__() - self.lstm = nn.LSTM(input_size=input_dim, - hidden_size=lstm_dim, - num_layers=num_lstm_layers, - batch_first=True) + self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True) self.linear = nn.Linear(lstm_dim, proj_dim, bias=True) self.relu = nn.ReLU() + def forward(self, x): _, (hidden, _) = self.lstm(x) return self.relu(self.linear(hidden[-1])) + class SpeakerEncoder(nn.Module): def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True): super().__init__() @@ -106,7 +106,5 @@ class SpeakerEncoder(nn.Module): if embed is None: embed = self.inference(frames) else: - embed[cur_iter <= num_iters, :] += self.inference( - frames[cur_iter <= num_iters, :, :] - ) + embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :]) return embed / num_iters diff --git a/TTS/speaker_encoder/utils/generic_utils.py b/TTS/speaker_encoder/utils/generic_utils.py index 47bf79cc..38f9870d 100644 --- a/TTS/speaker_encoder/utils/generic_utils.py +++ b/TTS/speaker_encoder/utils/generic_utils.py @@ -9,108 +9,114 @@ from TTS.utils.generic_utils import check_argument def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_model(c): - model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'], - c.model['lstm_dim'], c.model['num_lstm_layers']) + model = SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"]) return model -def save_checkpoint(model, optimizer, model_loss, out_path, - current_step, epoch): - checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step) +def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch): + checkpoint_path = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(out_path, checkpoint_path) print(" | | > Checkpoint saving : {}".format(checkpoint_path)) new_state_dict = model.state_dict() state = { - 'model': new_state_dict, - 'optimizer': optimizer.state_dict() if optimizer is not None else None, - 'step': current_step, - 'epoch': epoch, - 'loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": new_state_dict, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "epoch": epoch, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), } torch.save(state, checkpoint_path) -def save_best_model(model, optimizer, model_loss, best_loss, out_path, - current_step): +def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step): if model_loss < best_loss: new_state_dict = model.state_dict() state = { - 'model': new_state_dict, - 'optimizer': optimizer.state_dict(), - 'step': current_step, - 'loss': model_loss, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": new_state_dict, + "optimizer": optimizer.state_dict(), + "step": current_step, + "loss": model_loss, + "date": datetime.date.today().strftime("%B %d, %Y"), } best_loss = model_loss - bestmodel_path = 'best_model.pth.tar' + bestmodel_path = "best_model.pth.tar" bestmodel_path = os.path.join(out_path, bestmodel_path) - print("\n > BEST MODEL ({0:.5f}) : {1:}".format( - model_loss, bestmodel_path)) + print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path)) torch.save(state, bestmodel_path) return best_loss def check_config_speaker_encoder(c): """Check the config.json file of the speaker encoder""" - check_argument('run_name', c, restricted=True, val_type=str) - check_argument('run_description', c, val_type=str) + check_argument("run_name", c, restricted=True, val_type=str) + check_argument("run_description", c, val_type=str) # audio processing parameters - check_argument('audio', c, restricted=True, val_type=dict) - check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + check_argument("audio", c, restricted=True, val_type=dict) + check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c["audio"], + restricted=True, + val_type=float, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument( + "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length" + ) + check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000) # training parameters - check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str) - check_argument('grad_clip', c, restricted=True, val_type=float) - check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('lr_decay', c, restricted=True, val_type=bool) - check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) - check_argument('num_speakers_in_batch', c, restricted=True, val_type=int) - check_argument('num_loader_workers', c, restricted=True, val_type=int) - check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str) + check_argument("grad_clip", c, restricted=True, val_type=float) + check_argument("epochs", c, restricted=True, val_type=int, min_val=1) + check_argument("lr", c, restricted=True, val_type=float, min_val=0) + check_argument("lr_decay", c, restricted=True, val_type=bool) + check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0) + check_argument("tb_model_param_stats", c, restricted=True, val_type=bool) + check_argument("num_speakers_in_batch", c, restricted=True, val_type=int) + check_argument("num_loader_workers", c, restricted=True, val_type=int) + check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0) # checkpoint and output parameters - check_argument('steps_plot_stats', c, restricted=True, val_type=int) - check_argument('checkpoint', c, restricted=True, val_type=bool) - check_argument('save_step', c, restricted=True, val_type=int) - check_argument('print_step', c, restricted=True, val_type=int) - check_argument('output_path', c, restricted=True, val_type=str) + check_argument("steps_plot_stats", c, restricted=True, val_type=int) + check_argument("checkpoint", c, restricted=True, val_type=bool) + check_argument("save_step", c, restricted=True, val_type=int) + check_argument("print_step", c, restricted=True, val_type=int) + check_argument("output_path", c, restricted=True, val_type=str) # model parameters - check_argument('model', c, restricted=True, val_type=dict) - check_argument('input_dim', c['model'], restricted=True, val_type=int) - check_argument('proj_dim', c['model'], restricted=True, val_type=int) - check_argument('lstm_dim', c['model'], restricted=True, val_type=int) - check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int) - check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool) + check_argument("model", c, restricted=True, val_type=dict) + check_argument("input_dim", c["model"], restricted=True, val_type=int) + check_argument("proj_dim", c["model"], restricted=True, val_type=int) + check_argument("lstm_dim", c["model"], restricted=True, val_type=int) + check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int) + check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool) # in-memory storage parameters - check_argument('storage', c, restricted=True, val_type=dict) - check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) - check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100) - check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("storage", c, restricted=True, val_type=dict) + check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0) + check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100) + check_argument("additive_noise", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0) # datasets - checking only the first entry - check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - check_argument('name', dataset_entry, restricted=True, val_type=str) - check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) + check_argument("datasets", c, restricted=True, val_type=list) + for dataset_entry in c["datasets"]: + check_argument("name", dataset_entry, restricted=True, val_type=str) + check_argument("path", dataset_entry, restricted=True, val_type=str) + check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list]) + check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str) diff --git a/TTS/speaker_encoder/utils/prepare_voxceleb.py b/TTS/speaker_encoder/utils/prepare_voxceleb.py index 758e1cb3..58ff9dad 100644 --- a/TTS/speaker_encoder/utils/prepare_voxceleb.py +++ b/TTS/speaker_encoder/utils/prepare_voxceleb.py @@ -17,7 +17,7 @@ # Only support eager mode and TF>=2.0.0 # pylint: disable=no-member, invalid-name, relative-beyond-top-level # pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes -''' voxceleb 1 & 2 ''' +""" voxceleb 1 & 2 """ import os import sys @@ -32,40 +32,38 @@ import soundfile as sf gfile = tf.compat.v1.gfile SUBSETS = { - "vox1_dev_wav": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"], - "vox1_test_wav": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], - "vox2_dev_aac": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", - "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"], - "vox2_test_aac": - ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"] + "vox1_dev_wav": [ + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad", + ], + "vox1_test_wav": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"], + "vox2_dev_aac": [ + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag", + "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah", + ], + "vox2_test_aac": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"], } MD5SUM = { "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b", "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402", "vox1_test_wav": "185fdc63c3c739954633d50379a3d102", - "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312" + "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312", } -USER = { - "user": "", - "password": "" -} +USER = {"user": "", "password": ""} speaker_id_dict = {} + def download_and_extract(directory, subset, urls): """Download and extract the given split of dataset. @@ -83,31 +81,30 @@ def download_and_extract(directory, subset, urls): if os.path.exists(zip_filepath): continue logging.info("Downloading %s to %s" % (url, zip_filepath)) - subprocess.call('wget %s --user %s --password %s -O %s' % - (url, USER["user"], USER["password"], zip_filepath), shell=True) + subprocess.call( + "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath), + shell=True, + ) statinfo = os.stat(zip_filepath) - logging.info( - "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size) - ) + logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)) # concatenate all parts into zip files if ".zip" not in zip_filepath: zip_filepath = "_".join(zip_filepath.split("_")[:-1]) - subprocess.call('cat %s* > %s.zip' % - (zip_filepath, zip_filepath), shell=True) + subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True) zip_filepath += ".zip" extract_path = zip_filepath.strip(".zip") # check zip file md5sum - md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest() + md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest() if md5 != MD5SUM[subset]: raise ValueError("md5sum of %s mismatch" % zip_filepath) with zipfile.ZipFile(zip_filepath, "r") as zfile: zfile.extractall(directory) extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename) - subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True) + subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True) finally: # gfile.Remove(zip_filepath) pass @@ -148,8 +145,7 @@ def decode_aac_with_ffmpeg(aac_file, wav_file): return True -def convert_audio_and_make_label(input_dir, subset, - output_dir, output_file): +def convert_audio_and_make_label(input_dir, subset, output_dir, output_file): """Optionally convert AAC to WAV and make speaker labels. Args: input_dir: the directory which holds the input dataset. @@ -167,7 +163,7 @@ def convert_audio_and_make_label(input_dir, subset, for filename in filenames: name, ext = os.path.splitext(filename) if ext.lower() == ".wav": - _, ext2 = (os.path.splitext(name)) + _, ext2 = os.path.splitext(name) if ext2: continue wav_file = os.path.join(root, filename) @@ -186,15 +182,12 @@ def convert_audio_and_make_label(input_dir, subset, speaker_id_dict[speaker_name] = num # wav_filesize = os.path.getsize(wav_file) wav_length = len(sf.read(wav_file)[0]) - files.append( - (os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name) - ) + files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)) # Write to CSV file which contains four columns: # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name". csv_file_path = os.path.join(output_dir, output_file) - df = pandas.DataFrame( - data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) + df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"]) df.to_csv(csv_file_path, index=False, sep="\t") logging.info("Successfully generated csv file {}".format(csv_file_path)) @@ -205,19 +198,14 @@ def processor(directory, subset, force_process): if subset not in urls: raise ValueError(subset, "is not in voxceleb") - subset_csv = os.path.join(directory, subset + '.csv') + subset_csv = os.path.join(directory, subset + ".csv") if not force_process and os.path.exists(subset_csv): return subset_csv logging.info("Downloading and process the voxceleb in %s", directory) logging.info("Preparing subset %s", subset) download_and_extract(directory, subset, urls[subset]) - convert_audio_and_make_label( - directory, - subset, - directory, - subset + ".csv" - ) + convert_audio_and_make_label(directory, subset, directory, subset + ".csv") logging.info("Finished downloading and processing") return subset_csv diff --git a/TTS/tts/datasets/TTSDataset.py b/TTS/tts/datasets/TTSDataset.py index eaabb42b..3c791625 100644 --- a/TTS/tts/datasets/TTSDataset.py +++ b/TTS/tts/datasets/TTSDataset.py @@ -7,31 +7,31 @@ import numpy as np import torch import tqdm from torch.utils.data import Dataset -from TTS.tts.utils.data import (prepare_data, prepare_stop_target, - prepare_tensor) -from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence, - text_to_sequence) +from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor +from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence class MyDataset(Dataset): - def __init__(self, - outputs_per_step, - text_cleaner, - compute_linear_spec, - ap, - meta_data, - tp=None, - add_blank=False, - batch_group_size=0, - min_seq_len=0, - max_seq_len=float("inf"), - use_phonemes=True, - phoneme_cache_path=None, - phoneme_language="en-us", - enable_eos_bos=False, - speaker_mapping=None, - use_noise_augment=False, - verbose=False): + def __init__( + self, + outputs_per_step, + text_cleaner, + compute_linear_spec, + ap, + meta_data, + tp=None, + add_blank=False, + batch_group_size=0, + min_seq_len=0, + max_seq_len=float("inf"), + use_phonemes=True, + phoneme_cache_path=None, + phoneme_language="en-us", + enable_eos_bos=False, + speaker_mapping=None, + use_noise_augment=False, + verbose=False, + ): """ Args: outputs_per_step (int): number of time frames predicted per step. @@ -53,7 +53,7 @@ class MyDataset(Dataset): use_noise_augment (bool): enable adding random noise to wav for augmentation. verbose (bool): print diagnostic information. """ - super(MyDataset, self).__init__() + super().__init__() self.batch_group_size = batch_group_size self.items = meta_data self.outputs_per_step = outputs_per_step @@ -88,45 +88,42 @@ class MyDataset(Dataset): @staticmethod def load_np(filename): - data = np.load(filename).astype('float32') + data = np.load(filename).astype("float32") return data @staticmethod - def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, - language, tp, add_blank): + def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank): """generate a phoneme sequence from text. since the usage is for subsequent caching, we never add bos and eos chars here. Instead we add those dynamically later; based on the config option.""" - phonemes = phoneme_to_sequence(text, [cleaners], - language=language, - enable_eos_bos=False, - tp=tp, - add_blank=add_blank) + phonemes = phoneme_to_sequence( + text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank + ) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @staticmethod - def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path, - enable_eos_bos, cleaners, language, - tp, add_blank): + def _load_or_generate_phoneme_sequence( + wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank + ): file_name = os.path.splitext(os.path.basename(wav_file))[0] # different names for normal phonemes and with blank chars. - file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy' - cache_path = os.path.join(phoneme_cache_path, - file_name + file_name_ext) + file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy" + cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext) try: phonemes = np.load(cache_path) except FileNotFoundError: phonemes = MyDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, tp, add_blank) + text, cache_path, cleaners, language, tp, add_blank + ) except (ValueError, IOError): - print(" [!] failed loading phonemes for {}. " - "Recomputing.".format(wav_file)) + print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file)) phonemes = MyDataset._generate_and_cache_phoneme_sequence( - text, cache_path, cleaners, language, tp, add_blank) + text, cache_path, cleaners, language, tp, add_blank + ) if enable_eos_bos: phonemes = pad_with_eos_bos(phonemes, tp=tp) phonemes = np.asarray(phonemes, dtype=np.int32) @@ -150,15 +147,20 @@ class MyDataset(Dataset): if not self.input_seq_computed: if self.use_phonemes: text = self._load_or_generate_phoneme_sequence( - wav_file, text, self.phoneme_cache_path, - self.enable_eos_bos, self.cleaners, self.phoneme_language, - self.tp, self.add_blank) + wav_file, + text, + self.phoneme_cache_path, + self.enable_eos_bos, + self.cleaners, + self.phoneme_language, + self.tp, + self.add_blank, + ) else: - text = np.asarray(text_to_sequence(text, [self.cleaners], - tp=self.tp, - add_blank=self.add_blank), - dtype=np.int32) + text = np.asarray( + text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32 + ) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] @@ -173,12 +175,12 @@ class MyDataset(Dataset): return self.load_data(100) sample = { - 'text': text, - 'wav': wav, - 'attn': attn, - 'item_idx': self.items[idx][1], - 'speaker_name': speaker_name, - 'wav_file_name': os.path.basename(wav_file) + "text": text, + "wav": wav, + "attn": attn, + "item_idx": self.items[idx][1], + "speaker_name": speaker_name, + "wav_file_name": os.path.basename(wav_file), } return sample @@ -187,8 +189,7 @@ class MyDataset(Dataset): item = args[0] func_args = args[1] text, wav_file, *_ = item - phonemes = MyDataset._load_or_generate_phoneme_sequence( - wav_file, text, *func_args) + phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args) return phonemes def compute_input_seq(self, num_workers=0): @@ -199,17 +200,19 @@ class MyDataset(Dataset): print(" | > Computing input sequences ...") for idx, item in enumerate(tqdm.tqdm(self.items)): text, *_ = item - sequence = np.asarray(text_to_sequence( - text, [self.cleaners], - tp=self.tp, - add_blank=self.add_blank), - dtype=np.int32) + sequence = np.asarray( + text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32 + ) self.items[idx][0] = sequence else: func_args = [ - self.phoneme_cache_path, self.enable_eos_bos, self.cleaners, - self.phoneme_language, self.tp, self.add_blank + self.phoneme_cache_path, + self.enable_eos_bos, + self.cleaners, + self.phoneme_language, + self.tp, + self.add_blank, ] if self.verbose: print(" | > Computing phonemes ...") @@ -220,10 +223,11 @@ class MyDataset(Dataset): else: with Pool(num_workers) as p: phonemes = list( - tqdm.tqdm(p.imap(MyDataset._phoneme_worker, - [[item, func_args] - for item in self.items]), - total=len(self.items))) + tqdm.tqdm( + p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]), + total=len(self.items), + ) + ) for idx, p in enumerate(phonemes): self.items[idx][0] = p @@ -255,8 +259,10 @@ class MyDataset(Dataset): print(" | > Min length sequence: {}".format(np.min(lengths))) print(" | > Avg length sequence: {}".format(np.mean(lengths))) print( - " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}" - .format(self.max_seq_len, self.min_seq_len, len(ignored))) + " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format( + self.max_seq_len, self.min_seq_len, len(ignored) + ) + ) print(" | > Batch group size: {}.".format(self.batch_group_size)) def __len__(self): @@ -267,11 +273,11 @@ class MyDataset(Dataset): def collate_fn(self, batch): r""" - Perform preprocessing and create a final data batch: - 1. Sort batch instances by text-length - 2. Convert Audio signal to Spectrograms. - 3. PAD sequences wrt r. - 4. Load to Torch. + Perform preprocessing and create a final data batch: + 1. Sort batch instances by text-length + 2. Convert Audio signal to Spectrograms. + 3. PAD sequences wrt r. + 4. Load to Torch. """ # Puts each data field into a tensor with outer dimension batch size @@ -280,44 +286,29 @@ class MyDataset(Dataset): text_lenghts = np.array([len(d["text"]) for d in batch]) # sort items with text input length for RNN efficiency - text_lenghts, ids_sorted_decreasing = torch.sort( - torch.LongTensor(text_lenghts), dim=0, descending=True) + text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True) - wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing] - item_idxs = [ - batch[idx]['item_idx'] for idx in ids_sorted_decreasing - ] - text = [batch[idx]['text'] for idx in ids_sorted_decreasing] + wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing] + item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing] + text = [batch[idx]["text"] for idx in ids_sorted_decreasing] - speaker_name = [ - batch[idx]['speaker_name'] for idx in ids_sorted_decreasing - ] + speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing] # get speaker embeddings if self.speaker_mapping is not None: - wav_files_names = [ - batch[idx]['wav_file_name'] - for idx in ids_sorted_decreasing - ] - speaker_embedding = [ - self.speaker_mapping[w]['embedding'] - for w in wav_files_names - ] + wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing] + speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names] else: speaker_embedding = None # compute features - mel = [self.ap.melspectrogram(w).astype('float32') for w in wav] + mel = [self.ap.melspectrogram(w).astype("float32") for w in wav] mel_lengths = [m.shape[1] for m in mel] # compute 'stop token' targets - stop_targets = [ - np.array([0.] * (mel_len - 1) + [1.]) - for mel_len in mel_lengths - ] + stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths] # PAD stop targets - stop_targets = prepare_stop_target(stop_targets, - self.outputs_per_step) + stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step) # PAD sequences with longest instance in the batch text = prepare_data(text).astype(np.int32) @@ -340,9 +331,7 @@ class MyDataset(Dataset): # compute linear spectrogram if self.compute_linear_spec: - linear = [ - self.ap.spectrogram(w).astype('float32') for w in wav - ] + linear = [self.ap.spectrogram(w).astype("float32") for w in wav] linear = prepare_tensor(linear, self.outputs_per_step) linear = linear.transpose(0, 2, 1) assert mel.shape[1] == linear.shape[1] @@ -351,8 +340,8 @@ class MyDataset(Dataset): linear = None # collate attention alignments - if batch[0]['attn'] is not None: - attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing] + if batch[0]["attn"] is not None: + attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing] for idx, attn in enumerate(attns): pad2 = mel.shape[1] - attn.shape[1] pad1 = text.shape[1] - attn.shape[0] @@ -362,8 +351,24 @@ class MyDataset(Dataset): attns = torch.FloatTensor(attns).unsqueeze(1) else: attns = None - return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \ - stop_targets, item_idxs, speaker_embedding, attns + return ( + text, + text_lenghts, + speaker_name, + linear, + mel, + mel_lengths, + stop_targets, + item_idxs, + speaker_embedding, + attns, + ) - raise TypeError(("batch must contain tensors, numbers, dicts or lists;\ - found {}".format(type(batch[0])))) + raise TypeError( + ( + "batch must contain tensors, numbers, dicts or lists;\ + found {}".format( + type(batch[0]) + ) + ) + ) diff --git a/TTS/tts/datasets/preprocess.py b/TTS/tts/datasets/preprocess.py index 12148b1e..14c7d4c5 100644 --- a/TTS/tts/datasets/preprocess.py +++ b/TTS/tts/datasets/preprocess.py @@ -13,14 +13,15 @@ from TTS.tts.utils.generic_utils import split_dataset # UTILITIES #################### + def load_meta_data(datasets, eval_split=True): meta_data_train_all = [] meta_data_eval_all = [] if eval_split else None for dataset in datasets: - name = dataset['name'] - root_path = dataset['path'] - meta_file_train = dataset['meta_file_train'] - meta_file_val = dataset['meta_file_val'] + name = dataset["name"] + root_path = dataset["path"] + meta_file_train = dataset["meta_file_train"] + meta_file_val = dataset["meta_file_val"] # setup the right data processor preprocessor = get_preprocessor_by_name(name) # load train set @@ -35,8 +36,8 @@ def load_meta_data(datasets, eval_split=True): meta_data_eval_all += meta_data_eval meta_data_train_all += meta_data_train # load attention masks for duration predictor training - if 'meta_file_attn_mask' in dataset and dataset['meta_file_attn_mask'] is not None: - meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask'])) + if "meta_file_attn_mask" in dataset and dataset["meta_file_attn_mask"] is not None: + meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"])) for idx, ins in enumerate(meta_data_train_all): attn_file = meta_data[ins[1]].strip() meta_data_train_all[idx].append(attn_file) @@ -49,12 +50,12 @@ def load_meta_data(datasets, eval_split=True): def load_attention_mask_meta_data(metafile_path): """Load meta data file created by compute_attention_masks.py""" - with open(metafile_path, 'r') as f: + with open(metafile_path, "r") as f: lines = f.readlines() meta_data = [] for line in lines: - wav_file, attn_file = line.split('|') + wav_file, attn_file = line.split("|") meta_data.append([wav_file, attn_file]) return meta_data @@ -69,6 +70,7 @@ def get_preprocessor_by_name(name): # DATASETS ######################## + def tweb(root_path, meta_file): """Normalize TWEB dataset. https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset @@ -76,10 +78,10 @@ def tweb(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "tweb" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('\t') - wav_file = os.path.join(root_path, cols[0] + '.wav') + cols = line.split("\t") + wav_file = os.path.join(root_path, cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -90,9 +92,9 @@ def mozilla(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "mozilla" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") wav_file = cols[1].strip() text = cols[0].strip() wav_file = os.path.join(root_path, "wavs", wav_file) @@ -105,9 +107,9 @@ def mozilla_de(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "mozilla" - with open(txt_file, 'r', encoding="ISO 8859-1") as ttf: + with open(txt_file, "r", encoding="ISO 8859-1") as ttf: for line in ttf: - cols = line.strip().split('|') + cols = line.strip().split("|") wav_file = cols[0].strip() text = cols[1].strip() folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL" @@ -118,8 +120,7 @@ def mozilla_de(root_path, meta_file): def mailabs(root_path, meta_files=None): """Normalizes M-AI-Labs meta data files to TTS format""" - speaker_regex = re.compile( - "by_book/(male|female)/(?P[^/]+)/") + speaker_regex = re.compile("by_book/(male|female)/(?P[^/]+)/") if meta_files is None: csv_files = glob(root_path + "/**/metadata.csv", recursive=True) else: @@ -135,21 +136,18 @@ def mailabs(root_path, meta_files=None): continue speaker_name = speaker_name_match.group("speaker_name") print(" | > {}".format(csv_file)) - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") if meta_files is None: - wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav') + wav_file = os.path.join(folder, "wavs", cols[0] + ".wav") else: - wav_file = os.path.join(root_path, - folder.replace("metadata.csv", ""), - 'wavs', cols[0] + '.wav') + wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav") if os.path.isfile(wav_file): text = cols[1].strip() items.append([text, wav_file, speaker_name]) else: - raise RuntimeError("> File %s does not exist!" % - (wav_file)) + raise RuntimeError("> File %s does not exist!" % (wav_file)) return items @@ -159,10 +157,10 @@ def ljspeech(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r', encoding="utf-8") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -171,15 +169,15 @@ def ljspeech(root_path, meta_file): def sam_accenture(root_path, meta_file): """Normalizes the sam-accenture meta data file to TTS format https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files""" - xml_file = os.path.join(root_path, 'voice_over_recordings', meta_file) + xml_file = os.path.join(root_path, "voice_over_recordings", meta_file) xml_root = ET.parse(xml_file).getroot() items = [] speaker_name = "sam_accenture" - for item in xml_root.findall('./fileid'): + for item in xml_root.findall("./fileid"): text = item.text - wav_file = os.path.join(root_path, 'vo_voice_quality_transformation', item.get('id')+'.wav') + wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav") if not os.path.exists(wav_file): - print(f' [!] {wav_file} in metafile does not exist. Skipping...') + print(f" [!] {wav_file} in metafile does not exist. Skipping...") continue items.append([text, wav_file, speaker_name]) return items @@ -191,10 +189,10 @@ def ruslan(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r', encoding="utf-8") as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'RUSLAN', cols[0] + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav") text = cols[1] items.append([text, wav_file, speaker_name]) return items @@ -205,9 +203,9 @@ def css10(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "ljspeech" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - cols = line.split('|') + cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[1] items.append([text, wav_file, speaker_name]) @@ -219,10 +217,10 @@ def nancy(root_path, meta_file): txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "nancy" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: utt_id = line.split()[1] - text = line[line.find('"') + 1:line.rfind('"') - 1] + text = line[line.find('"') + 1 : line.rfind('"') - 1] wav_file = os.path.join(root_path, "wavn", utt_id + ".wav") items.append([text, wav_file, speaker_name]) return items @@ -232,7 +230,7 @@ def common_voice(root_path, meta_file): """Normalize the common voice meta data file to TTS format.""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: if line.startswith("client_id"): continue @@ -240,7 +238,7 @@ def common_voice(root_path, meta_file): text = cols[2] speaker_name = cols[0] wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav")) - items.append([text, wav_file, 'MCV_' + speaker_name]) + items.append([text, wav_file, "MCV_" + speaker_name]) return items @@ -250,19 +248,18 @@ def libri_tts(root_path, meta_files=None): if meta_files is None: meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True) for meta_file in meta_files: - _meta_file = os.path.basename(meta_file).split('.')[0] - speaker_name = _meta_file.split('_')[0] - chapter_id = _meta_file.split('_')[1] + _meta_file = os.path.basename(meta_file).split(".")[0] + speaker_name = _meta_file.split("_")[0] + chapter_id = _meta_file.split("_")[1] _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}") - with open(meta_file, 'r') as ttf: + with open(meta_file, "r") as ttf: for line in ttf: - cols = line.split('\t') - wav_file = os.path.join(_root_path, cols[0] + '.wav') + cols = line.split("\t") + wav_file = os.path.join(_root_path, cols[0] + ".wav") text = cols[1] - items.append([text, wav_file, 'LTTS_' + speaker_name]) + items.append([text, wav_file, "LTTS_" + speaker_name]) for item in items: - assert os.path.exists( - item[1]), f" [!] wav files don't exist - {item[1]}" + assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}" return items @@ -271,11 +268,10 @@ def custom_turkish(root_path, meta_file): items = [] speaker_name = "turkish-female" skipped_files = [] - with open(txt_file, 'r', encoding='utf-8') as ttf: + with open(txt_file, "r", encoding="utf-8") as ttf: for line in ttf: - cols = line.split('|') - wav_file = os.path.join(root_path, 'wavs', - cols[0].strip() + '.wav') + cols = line.split("|") + wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav") if not os.path.exists(wav_file): skipped_files.append(wav_file) continue @@ -287,14 +283,14 @@ def custom_turkish(root_path, meta_file): # ToDo: add the dataset link when the dataset is released publicly def brspeech(root_path, meta_file): - '''BRSpeech 3.0 beta''' + """BRSpeech 3.0 beta""" txt_file = os.path.join(root_path, meta_file) items = [] - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: if line.startswith("wav_filename"): continue - cols = line.split('|') + cols = line.split("|") wav_file = os.path.join(root_path, cols[0]) text = cols[2] speaker_name = cols[3] @@ -302,45 +298,41 @@ def brspeech(root_path, meta_file): return items -def vctk(root_path, meta_files=None, wavs_path='wav48'): +def vctk(root_path, meta_files=None, wavs_path="wav48"): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" test_speakers = meta_files items = [] meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for meta_file in meta_files: - _, speaker_id, txt_file = os.path.relpath(meta_file, - root_path).split(os.sep) - file_id = txt_file.split('.')[0] - if isinstance(test_speakers, - list): # if is list ignore this speakers ids + _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] + if isinstance(test_speakers, list): # if is list ignore this speakers ids if speaker_id in test_speakers: continue with open(meta_file) as file_text: text = file_text.readlines()[0] - wav_file = os.path.join(root_path, wavs_path, speaker_id, - file_id + '.wav') - items.append([text, wav_file, 'VCTK_' + speaker_id]) + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append([text, wav_file, "VCTK_" + speaker_id]) return items -def vctk_slim(root_path, meta_files=None, wavs_path='wav48'): +def vctk_slim(root_path, meta_files=None, wavs_path="wav48"): """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz""" items = [] txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True) for text_file in txt_files: - _, speaker_id, txt_file = os.path.relpath(text_file, - root_path).split(os.sep) - file_id = txt_file.split('.')[0] + _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep) + file_id = txt_file.split(".")[0] if isinstance(meta_files, list): # if is list ignore this speakers ids if speaker_id in meta_files: continue - wav_file = os.path.join(root_path, wavs_path, speaker_id, - file_id + '.wav') - items.append([None, wav_file, 'VCTK_' + speaker_id]) + wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav") + items.append([None, wav_file, "VCTK_" + speaker_id]) return items + # ======================================== VOX CELEB =========================================== def voxceleb2(root_path, meta_file=None): """ @@ -365,31 +357,33 @@ def _voxcel_x(root_path, meta_file, voxcel_idx): # if not exists meta file, crawl recursively for 'wav' files if meta_file is not None: - with open(str(meta_file), 'r') as f: - return [x.strip().split('|') for x in f.readlines()] + with open(str(meta_file), "r") as f: + return [x.strip().split("|") for x in f.readlines()] elif not cache_to.exists(): cnt = 0 meta_data = [] wav_files = voxceleb_path.rglob("**/*.wav") - for path in tqdm(wav_files, desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", - total=expected_count): + for path in tqdm( + wav_files, + desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.", + total=expected_count, + ): speaker_id = str(Path(path).parent.parent.stem) - assert speaker_id.startswith('id') + assert speaker_id.startswith("id") text = None # VoxCel does not provide transciptions, and they are not needed for training the SE meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n") cnt += 1 - with open(str(cache_to), 'w') as f: + with open(str(cache_to), "w") as f: f.write("".join(meta_data)) if cnt < expected_count: raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}") - with open(str(cache_to), 'r') as f: - return [x.strip().split('|') for x in f.readlines()] + with open(str(cache_to), "r") as f: + return [x.strip().split("|") for x in f.readlines()] - -def baker(root_path: str, meta_file: str) -> List[List[str]]: +def baker(root_path: str, meta_file: str) -> List[List[str]]: """Normalizes the Baker meta data file to TTS format Args: @@ -401,9 +395,9 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]: txt_file = os.path.join(root_path, meta_file) items = [] speaker_name = "baker" - with open(txt_file, 'r') as ttf: + with open(txt_file, "r") as ttf: for line in ttf: - wav_name, text = line.rstrip('\n').split("|") + wav_name, text = line.rstrip("\n").split("|") wav_path = os.path.join(root_path, "clips_22", wav_name) items.append([text, wav_path, speaker_name]) return items diff --git a/TTS/tts/layers/align_tts/mdn.py b/TTS/tts/layers/align_tts/mdn.py index f5847cb4..cdb33252 100644 --- a/TTS/tts/layers/align_tts/mdn.py +++ b/TTS/tts/layers/align_tts/mdn.py @@ -5,6 +5,7 @@ class MDNBlock(nn.Module): """Mixture of Density Network implementation https://arxiv.org/pdf/2003.01950.pdf """ + def __init__(self, in_channels, out_channels): super().__init__() self.out_channels = out_channels @@ -24,6 +25,6 @@ class MDNBlock(nn.Module): mu_sigma = self.conv2(o) # TODO: check this sigmoid # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :]) - mu = mu_sigma[:, :self.out_channels//2, :] - log_sigma = mu_sigma[:, self.out_channels//2:, :] + mu = mu_sigma[:, : self.out_channels // 2, :] + log_sigma = mu_sigma[:, self.out_channels // 2 :, :] return mu, log_sigma diff --git a/TTS/tts/layers/feed_forward/decoder.py b/TTS/tts/layers/feed_forward/decoder.py index 5293e8bc..7e145a6c 100644 --- a/TTS/tts/layers/feed_forward/decoder.py +++ b/TTS/tts/layers/feed_forward/decoder.py @@ -31,15 +31,16 @@ class WaveNetDecoder(nn.Module): hidden_channels (int): number of hidden channels for prenet and postnet. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params): super().__init__() # prenet - self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1) + self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1) # wavenet layers - self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params) + self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params) # postnet self.postnet = [ - torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1), + torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1), torch.nn.ReLU(), torch.nn.Conv1d(hidden_channels, hidden_channels, 1), torch.nn.ReLU(), @@ -77,12 +78,12 @@ class RelativePositionTransformerDecoder(nn.Module): hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1) - self.rel_pos_transformer = RelativePositionTransformer( - in_channels, out_channels, hidden_channels, **params) + self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument o = self.prenet(x) * x_mask @@ -107,6 +108,7 @@ class FFTransformerDecoder(nn.Module): hidden_channels (int): number of hidden channels including Transformer layers. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, params): super().__init__() @@ -117,7 +119,7 @@ class FFTransformerDecoder(nn.Module): # TODO: handle multi-speaker x_mask = 1 if x_mask is None else x_mask o = self.transformer_block(x) * x_mask - o = self.postnet(o)* x_mask + o = self.postnet(o) * x_mask return o @@ -141,19 +143,15 @@ class ResidualConv1dBNDecoder(nn.Module): hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers. params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.res_conv_block = ResidualConv1dBNBlock(in_channels, - hidden_channels, - hidden_channels, **params) + self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params) self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1) self.postnet = nn.Sequential( - Conv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, - params['kernel_size'], - 1, - num_conv_blocks=2), + Conv1dBNBlock( + hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2 + ), nn.Conv1d(hidden_channels, out_channels, 1), ) @@ -178,17 +176,18 @@ class Decoder(nn.Module): # pylint: disable=dangerous-default-value def __init__( - self, - out_channels, - in_hidden_channels, - decoder_type='residual_conv_bn', - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17 - }, - c_in_channels=0): + self, + out_channels, + in_hidden_channels, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + c_in_channels=0, + ): super().__init__() if decoder_type.lower() == "relative_position_transformer": @@ -196,23 +195,27 @@ class Decoder(nn.Module): in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, - params=decoder_params) - elif decoder_type.lower() == 'residual_conv_bn': + params=decoder_params, + ) + elif decoder_type.lower() == "residual_conv_bn": self.decoder = ResidualConv1dBNDecoder( in_channels=in_hidden_channels, out_channels=out_channels, hidden_channels=in_hidden_channels, - params=decoder_params) - elif decoder_type.lower() == 'wavenet': - self.decoder = WaveNetDecoder(in_channels=in_hidden_channels, - out_channels=out_channels, - hidden_channels=in_hidden_channels, - c_in_channels=c_in_channels, - params=decoder_params) - elif decoder_type.lower() == 'fftransformer': + params=decoder_params, + ) + elif decoder_type.lower() == "wavenet": + self.decoder = WaveNetDecoder( + in_channels=in_hidden_channels, + out_channels=out_channels, + hidden_channels=in_hidden_channels, + c_in_channels=c_in_channels, + params=decoder_params, + ) + elif decoder_type.lower() == "fftransformer": self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params) else: - raise ValueError(f'[!] Unknown decoder type - {decoder_type}') + raise ValueError(f"[!] Unknown decoder type - {decoder_type}") def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/feed_forward/duration_predictor.py b/TTS/tts/layers/feed_forward/duration_predictor.py index 5c5c4f3a..5392aeca 100644 --- a/TTS/tts/layers/feed_forward/duration_predictor.py +++ b/TTS/tts/layers/feed_forward/duration_predictor.py @@ -16,16 +16,19 @@ class DurationPredictor(nn.Module): Args: hidden_channels (int): number of channels in the inner layers. """ + def __init__(self, hidden_channels): super().__init__() - self.layers = nn.ModuleList([ - Conv1dBN(hidden_channels, hidden_channels, 4, 1), - Conv1dBN(hidden_channels, hidden_channels, 3, 1), - Conv1dBN(hidden_channels, hidden_channels, 1, 1), - nn.Conv1d(hidden_channels, 1, 1) - ]) + self.layers = nn.ModuleList( + [ + Conv1dBN(hidden_channels, hidden_channels, 4, 1), + Conv1dBN(hidden_channels, hidden_channels, 3, 1), + Conv1dBN(hidden_channels, hidden_channels, 1, 1), + nn.Conv1d(hidden_channels, 1, 1), + ] + ) def forward(self, x, x_mask): """ diff --git a/TTS/tts/layers/feed_forward/encoder.py b/TTS/tts/layers/feed_forward/encoder.py index 6bc46cfa..a50898a4 100644 --- a/TTS/tts/layers/feed_forward/encoder.py +++ b/TTS/tts/layers/feed_forward/encoder.py @@ -1,7 +1,7 @@ from torch import nn from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer -from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock +from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock from TTS.tts.layers.generic.transformer import FFTransformerBlock @@ -16,17 +16,19 @@ class RelativePositionTransformerEncoder(nn.Module): hidden_channels (int): number of hidden channels params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.prenet = ResidualConv1dBNBlock(in_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_res_blocks=3, - num_conv_blocks=1, - dilations=[1, 1, 1]) - self.rel_pos_transformer = RelativePositionTransformer( - hidden_channels, out_channels, hidden_channels, **params) + self.prenet = ResidualConv1dBNBlock( + in_channels, + hidden_channels, + hidden_channels, + kernel_size=5, + num_res_blocks=3, + num_conv_blocks=1, + dilations=[1, 1, 1], + ) + self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument if x_mask is None: @@ -47,20 +49,20 @@ class ResidualConv1dBNEncoder(nn.Module): hidden_channels (int): number of hidden channels params (dict): dictionary for residual convolutional blocks. """ + def __init__(self, in_channels, out_channels, hidden_channels, params): super().__init__() - self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), - nn.ReLU()) - self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, **params) + self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU()) + self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params) - self.postnet = nn.Sequential(*[ - nn.Conv1d(hidden_channels, hidden_channels, 1), - nn.ReLU(), - nn.BatchNorm1d(hidden_channels), - nn.Conv1d(hidden_channels, out_channels, 1) - ]) + self.postnet = nn.Sequential( + *[ + nn.Conv1d(hidden_channels, hidden_channels, 1), + nn.ReLU(), + nn.BatchNorm1d(hidden_channels), + nn.Conv1d(hidden_channels, out_channels, 1), + ] + ) def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument if x_mask is None: @@ -115,18 +117,15 @@ class Encoder(nn.Module): } ``` """ + def __init__( - self, - in_hidden_channels, - out_channels, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }, - c_in_channels=0): + self, + in_hidden_channels, + out_channels, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + c_in_channels=0, + ): super().__init__() self.out_channels = out_channels self.in_channels = in_hidden_channels @@ -137,21 +136,22 @@ class Encoder(nn.Module): # init encoder if encoder_type.lower() == "relative_position_transformer": # text encoder + # pylint: disable=unexpected-keyword-arg self.encoder = RelativePositionTransformerEncoder( - in_hidden_channels, out_channels, in_hidden_channels, - encoder_params) # pylint: disable=unexpected-keyword-arg - elif encoder_type.lower() == 'residual_conv_bn': - self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, - out_channels, - in_hidden_channels, - encoder_params) - elif encoder_type.lower() == 'fftransformer': - assert in_hidden_channels == out_channels, \ - "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" - self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg + in_hidden_channels, out_channels, in_hidden_channels, encoder_params + ) + elif encoder_type.lower() == "residual_conv_bn": + self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params) + elif encoder_type.lower() == "fftransformer": + assert ( + in_hidden_channels == out_channels + ), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'" + # pylint: disable=unexpected-keyword-arg + self.encoder = FFTransformerBlock( + in_hidden_channels, **encoder_params + ) else: - raise NotImplementedError(' [!] unknown encoder type.') - + raise NotImplementedError(" [!] unknown encoder type.") def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/generic/gated_conv.py b/TTS/tts/layers/generic/gated_conv.py index ec95565a..9a29c449 100644 --- a/TTS/tts/layers/generic/gated_conv.py +++ b/TTS/tts/layers/generic/gated_conv.py @@ -10,6 +10,7 @@ class GatedConvBlock(nn.Module): kernel_size (int): convolution kernel size. dropout_p (float): dropout rate. """ + def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers): super().__init__() # class arguments @@ -20,21 +21,14 @@ class GatedConvBlock(nn.Module): self.norm_layers = nn.ModuleList() self.layers = nn.ModuleList() for _ in range(num_layers): - self.conv_layers += [ - nn.Conv1d(in_out_channels, - 2 * in_out_channels, - kernel_size, - padding=kernel_size // 2) - ] + self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)] self.norm_layers += [LayerNorm(2 * in_out_channels)] def forward(self, x, x_mask): o = x res = x for idx in range(self.num_layers): - o = nn.functional.dropout(o, - p=self.dropout_p, - training=self.training) + o = nn.functional.dropout(o, p=self.dropout_p, training=self.training) o = self.conv_layers[idx](o * x_mask) o = self.norm_layers[idx](o) o = nn.functional.glu(o, dim=1) diff --git a/TTS/tts/layers/generic/normalization.py b/TTS/tts/layers/generic/normalization.py index e3dbb52f..fd607b75 100644 --- a/TTS/tts/layers/generic/normalization.py +++ b/TTS/tts/layers/generic/normalization.py @@ -22,24 +22,17 @@ class LayerNorm(nn.Module): def forward(self, x): mean = torch.mean(x, 1, keepdim=True) - variance = torch.mean((x - mean)**2, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) x = (x - mean) * torch.rsqrt(variance + self.eps) x = x * self.gamma + self.beta return x class TemporalBatchNorm1d(nn.BatchNorm1d): - """Normalize each channel separately over time and batch. - """ - def __init__(self, - channels, - affine=True, - track_running_stats=True, - momentum=0.1): - super().__init__(channels, - affine=affine, - track_running_stats=track_running_stats, - momentum=momentum) + """Normalize each channel separately over time and batch.""" + + def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1): + super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum) def forward(self, x): return super().forward(x.transpose(2, 1)).transpose(2, 1) @@ -58,6 +51,7 @@ class ActNorm(nn.Module): - inputs: (B, C, T) - outputs: (B, C, T) """ + def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument super().__init__() self.channels = channels @@ -68,8 +62,7 @@ class ActNorm(nn.Module): def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument if x_mask is None: - x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype) x_len = torch.sum(x_mask, [1, 2]) if not self.initialized: self.initialize(x, x_mask) @@ -95,13 +88,11 @@ class ActNorm(nn.Module): denom = torch.sum(x_mask, [0, 2]) m = torch.sum(x * x_mask, [0, 2]) / denom m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom - v = m_sq - (m**2) + v = m_sq - (m ** 2) logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6)) - bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to( - dtype=self.bias.dtype) - logs_init = (-logs).view(*self.logs.shape).to( - dtype=self.logs.dtype) + bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype) + logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype) self.bias.data.copy_(bias_init) self.logs.data.copy_(logs_init) diff --git a/TTS/tts/layers/generic/pos_encoding.py b/TTS/tts/layers/generic/pos_encoding.py index 95330b4a..a1eaacea 100644 --- a/TTS/tts/layers/generic/pos_encoding.py +++ b/TTS/tts/layers/generic/pos_encoding.py @@ -11,20 +11,20 @@ class PositionalEncoding(nn.Module): channels (int): embedding size dropout (float): dropout parameter """ + def __init__(self, channels, dropout_p=0.0, max_len=5000): super().__init__() if channels % 2 != 0: raise ValueError( - "Cannot use sin/cos positional encoding with " - "odd channels (got channels={:d})".format(channels)) + "Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels) + ) pe = torch.zeros(max_len, channels) position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.pow(10000, - torch.arange(0, channels, 2).float() / channels) + div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels) pe[:, 0::2] = torch.sin(position.float() * div_term) pe[:, 1::2] = torch.cos(position.float() * div_term) pe = pe.unsqueeze(0).transpose(1, 2) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) if dropout_p > 0: self.dropout = nn.Dropout(p=dropout_p) self.channels = channels @@ -43,14 +43,15 @@ class PositionalEncoding(nn.Module): if self.pe.size(2) < x.size(2): raise RuntimeError( f"Sequence is {x.size(2)} but PositionalEncoding is" - f" limited to {self.pe.size(2)}. See max_len argument.") + f" limited to {self.pe.size(2)}. See max_len argument." + ) if mask is not None: - pos_enc = (self.pe[:, :, :x.size(2)] * mask) + pos_enc = self.pe[:, :, : x.size(2)] * mask else: - pos_enc = self.pe[:, :, :x.size(2)] + pos_enc = self.pe[:, :, : x.size(2)] x = x + pos_enc else: x = x + self.pe[:, :, first_idx:last_idx] - if hasattr(self, 'dropout'): + if hasattr(self, "dropout"): x = self.dropout(x) return x diff --git a/TTS/tts/layers/generic/res_conv_bn.py b/TTS/tts/layers/generic/res_conv_bn.py index 964afd0a..30c134cd 100644 --- a/TTS/tts/layers/generic/res_conv_bn.py +++ b/TTS/tts/layers/generic/res_conv_bn.py @@ -3,9 +3,10 @@ from torch import nn class ZeroTemporalPad(nn.Module): """Pad sequences to equal lentgh in the temporal dimension""" + def __init__(self, kernel_size, dilation): super().__init__() - total_pad = (dilation * (kernel_size - 1)) + total_pad = dilation * (kernel_size - 1) begin = total_pad // 2 end = total_pad - begin self.pad_layer = nn.ZeroPad2d((0, 0, begin, end)) @@ -27,9 +28,10 @@ class Conv1dBN(nn.Module): kernel_size (int): kernel size for convolutional filters. dilation (int): dilation for convolution layers. """ + def __init__(self, in_channels, out_channels, kernel_size, dilation): super().__init__() - padding = (dilation * (kernel_size - 1)) + padding = dilation * (kernel_size - 1) pad_s = padding // 2 pad_e = padding - pad_s self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation) @@ -55,14 +57,17 @@ class Conv1dBNBlock(nn.Module): dilation (int): dilation for convolution layers. num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2. """ + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2): super().__init__() self.conv_bn_blocks = [] for idx in range(num_conv_blocks): - layer = Conv1dBN(in_channels if idx == 0 else hidden_channels, - out_channels if idx == (num_conv_blocks - 1) else hidden_channels, - kernel_size, - dilation) + layer = Conv1dBN( + in_channels if idx == 0 else hidden_channels, + out_channels if idx == (num_conv_blocks - 1) else hidden_channels, + kernel_size, + dilation, + ) self.conv_bn_blocks.append(layer) self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks) @@ -91,18 +96,23 @@ class ResidualConv1dBNBlock(nn.Module): num_res_blocks (int, optional): number of residual blocks. Defaults to 13. num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2. """ - def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2): + + def __init__( + self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2 + ): super().__init__() assert len(dilations) == num_res_blocks self.res_blocks = nn.ModuleList() for idx, dilation in enumerate(dilations): - block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels, - out_channels if (idx + 1) == len(dilations) else hidden_channels, - hidden_channels, - kernel_size, - dilation, - num_conv_blocks) + block = Conv1dBNBlock( + in_channels if idx == 0 else hidden_channels, + out_channels if (idx + 1) == len(dilations) else hidden_channels, + hidden_channels, + kernel_size, + dilation, + num_conv_blocks, + ) self.res_blocks.append(block) def forward(self, x, x_mask=None): diff --git a/TTS/tts/layers/generic/time_depth_sep_conv.py b/TTS/tts/layers/generic/time_depth_sep_conv.py index c9a117c8..186cea02 100644 --- a/TTS/tts/layers/generic/time_depth_sep_conv.py +++ b/TTS/tts/layers/generic/time_depth_sep_conv.py @@ -5,12 +5,8 @@ from torch import nn class TimeDepthSeparableConv(nn.Module): """Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf It shows competative results with less computation and memory footprint.""" - def __init__(self, - in_channels, - hid_channels, - out_channels, - kernel_size, - bias=True): + + def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True): super().__init__() self.in_channels = in_channels @@ -62,28 +58,24 @@ class TimeDepthSeparableConv(nn.Module): class TimeDepthSeparableConvBlock(nn.Module): - def __init__(self, - in_channels, - hid_channels, - out_channels, - num_layers, - kernel_size, - bias=True): + def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True): super().__init__() assert (kernel_size - 1) % 2 == 0 assert num_layers > 1 self.layers = nn.ModuleList() layer = TimeDepthSeparableConv( - in_channels, hid_channels, - out_channels if num_layers == 1 else hid_channels, kernel_size, - bias) + in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias + ) self.layers.append(layer) for idx in range(num_layers - 1): layer = TimeDepthSeparableConv( - hid_channels, hid_channels, out_channels if - (idx + 1) == (num_layers - 1) else hid_channels, kernel_size, - bias) + hid_channels, + hid_channels, + out_channels if (idx + 1) == (num_layers - 1) else hid_channels, + kernel_size, + bias, + ) self.layers.append(layer) def forward(self, x, mask): diff --git a/TTS/tts/layers/generic/transformer.py b/TTS/tts/layers/generic/transformer.py index 2324938e..24d604f6 100644 --- a/TTS/tts/layers/generic/transformer.py +++ b/TTS/tts/layers/generic/transformer.py @@ -4,16 +4,9 @@ import torch.nn.functional as F class FFTransformer(nn.Module): - def __init__(self, - in_out_channels, - num_heads, - hidden_channels_ffn=1024, - kernel_size_fft=3, - dropout_p=0.1): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1): super().__init__() - self.self_attn = nn.MultiheadAttention(in_out_channels, - num_heads, - dropout=dropout_p) + self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p) padding = (kernel_size_fft - 1) // 2 self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding) @@ -27,11 +20,7 @@ class FFTransformer(nn.Module): def forward(self, src, src_mask=None, src_key_padding_mask=None): """😦 ugly looking with all the transposing """ src = src.permute(2, 0, 1) - src2, enc_align = self.self_attn(src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask) + src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask) src = self.norm1(src + src2) # T x B x D -> B x D x T src = src.permute(1, 2, 0) @@ -45,15 +34,19 @@ class FFTransformer(nn.Module): class FFTransformerBlock(nn.Module): - def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, - num_layers, dropout_p): + def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p): super().__init__() - self.fft_layers = nn.ModuleList([ - FFTransformer(in_out_channels=in_out_channels, - num_heads=num_heads, - hidden_channels_ffn=hidden_channels_ffn, - dropout_p=dropout_p) for _ in range(num_layers) - ]) + self.fft_layers = nn.ModuleList( + [ + FFTransformer( + in_out_channels=in_out_channels, + num_heads=num_heads, + hidden_channels_ffn=hidden_channels_ffn, + dropout_p=dropout_p, + ) + for _ in range(num_layers) + ] + ) def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument """ diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py index 97eee879..0c87e9df 100644 --- a/TTS/tts/layers/generic/wavenet.py +++ b/TTS/tts/layers/generic/wavenet.py @@ -32,15 +32,18 @@ class WN(torch.nn.Module): dropout_p (float): dropout rate. weight_norm (bool): enable/disable weight norm for convolution layers. """ - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_layers, - c_in_channels=0, - dropout_p=0, - weight_norm=True): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): super().__init__() assert kernel_size % 2 == 1 assert hidden_channels % 2 == 0 @@ -58,20 +61,16 @@ class WN(torch.nn.Module): # init conditioning layer if c_in_channels > 0: - cond_layer = torch.nn.Conv1d(c_in_channels, - 2 * hidden_channels * num_layers, 1) - self.cond_layer = torch.nn.utils.weight_norm(cond_layer, - name='weight') + cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") # intermediate layers for i in range(num_layers): - dilation = dilation_rate**i + dilation = dilation_rate ** i padding = int((kernel_size * dilation - dilation) / 2) - in_layer = torch.nn.Conv1d(hidden_channels, - 2 * hidden_channels, - kernel_size, - dilation=dilation, - padding=padding) - in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + in_layer = torch.nn.Conv1d( + hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") self.in_layers.append(in_layer) if i < num_layers - 1: @@ -79,10 +78,8 @@ class WN(torch.nn.Module): else: res_skip_channels = hidden_channels - res_skip_layer = torch.nn.Conv1d(hidden_channels, - res_skip_channels, 1) - res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, - name='weight') + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") self.res_skip_layers.append(res_skip_layer) # setup weight norm if not weight_norm: @@ -99,15 +96,14 @@ class WN(torch.nn.Module): x_in = self.dropout(x_in) if g is not None: cond_offset = i * 2 * self.hidden_channels - g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :] + g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :] else: g_l = torch.zeros_like(x_in) - acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, - n_channels_tensor) + acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor) res_skip_acts = self.res_skip_layers[i](acts) if i < self.num_layers - 1: - x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask - output = output + res_skip_acts[:, self.hidden_channels:, :] + x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask + output = output + res_skip_acts[:, self.hidden_channels :, :] else: output = output + res_skip_acts return output * x_mask @@ -140,28 +136,32 @@ class WNBlocks(nn.Module): weight_norm (bool): enable/disable weight norm for convolution layers. """ - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_blocks, - num_layers, - c_in_channels=0, - dropout_p=0, - weight_norm=True): + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_blocks, + num_layers, + c_in_channels=0, + dropout_p=0, + weight_norm=True, + ): super().__init__() self.wn_blocks = nn.ModuleList() for idx in range(num_blocks): - layer = WN(in_channels=in_channels if idx == 0 else hidden_channels, - hidden_channels=hidden_channels, - kernel_size=kernel_size, - dilation_rate=dilation_rate, - num_layers=num_layers, - c_in_channels=c_in_channels, - dropout_p=dropout_p, - weight_norm=weight_norm) + layer = WN( + in_channels=in_channels if idx == 0 else hidden_channels, + hidden_channels=hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + weight_norm=weight_norm, + ) self.wn_blocks.append(layer) def forward(self, x, x_mask=None, g=None): diff --git a/TTS/tts/layers/glow_tts/decoder.py b/TTS/tts/layers/glow_tts/decoder.py index 46533ed1..3cfcf461 100644 --- a/TTS/tts/layers/glow_tts/decoder.py +++ b/TTS/tts/layers/glow_tts/decoder.py @@ -18,14 +18,12 @@ def squeeze(x, x_mask=None, num_sqz=2): t = (t // num_sqz) * num_sqz x = x[:, :, :t] x_sqz = x.view(b, c, t // num_sqz, num_sqz) - x_sqz = x_sqz.permute(0, 3, 1, - 2).contiguous().view(b, c * num_sqz, t // num_sqz) + x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz) if x_mask is not None: - x_mask = x_mask[:, :, num_sqz - 1::num_sqz] + x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz] else: - x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype) return x_sqz * x_mask, x_mask @@ -34,20 +32,16 @@ def unsqueeze(x, x_mask=None, num_sqz=2): Note: each 's' is a n-dimensional vector. - [[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]] """ + [[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]""" b, c, t = x.size() x_unsqz = x.view(b, num_sqz, c // num_sqz, t) - x_unsqz = x_unsqz.permute(0, 2, 3, - 1).contiguous().view(b, c // num_sqz, - t * num_sqz) + x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz) if x_mask is not None: - x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, - num_sqz).view(b, 1, t * num_sqz) + x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz) else: - x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, - dtype=x.dtype) + x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype) return x_unsqz * x_mask, x_mask @@ -65,18 +59,21 @@ class Decoder(nn.Module): dropout_p (float): wavenet dropout rate. sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer. """ - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_flow_blocks, - num_coupling_layers, - dropout_p=0., - num_splits=4, - num_squeeze=2, - sigmoid_scale=False, - c_in_channels=0): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_flow_blocks, + num_coupling_layers, + dropout_p=0.0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + c_in_channels=0, + ): super().__init__() self.in_channels = in_channels @@ -94,18 +91,19 @@ class Decoder(nn.Module): self.flows = nn.ModuleList() for _ in range(num_flow_blocks): self.flows.append(ActNorm(channels=in_channels * num_squeeze)) + self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits)) self.flows.append( - InvConvNear(channels=in_channels * num_squeeze, - num_splits=num_splits)) - self.flows.append( - CouplingBlock(in_channels * num_squeeze, - hidden_channels, - kernel_size=kernel_size, - dilation_rate=dilation_rate, - num_layers=num_coupling_layers, - c_in_channels=c_in_channels, - dropout_p=dropout_p, - sigmoid_scale=sigmoid_scale)) + CouplingBlock( + in_channels * num_squeeze, + hidden_channels, + kernel_size=kernel_size, + dilation_rate=dilation_rate, + num_layers=num_coupling_layers, + c_in_channels=c_in_channels, + dropout_p=dropout_p, + sigmoid_scale=sigmoid_scale, + ) + ) def forward(self, x, x_mask, g=None, reverse=False): if not reverse: diff --git a/TTS/tts/layers/glow_tts/duration_predictor.py b/TTS/tts/layers/glow_tts/duration_predictor.py index a08f64a8..51d1066a 100644 --- a/TTS/tts/layers/glow_tts/duration_predictor.py +++ b/TTS/tts/layers/glow_tts/duration_predictor.py @@ -14,6 +14,7 @@ class DurationPredictor(nn.Module): kernel_size ([type]): [description] dropout_p ([type]): [description] """ + def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p): super().__init__() # class arguments @@ -23,15 +24,9 @@ class DurationPredictor(nn.Module): self.dropout_p = dropout_p # layers self.drop = nn.Dropout(dropout_p) - self.conv_1 = nn.Conv1d(in_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) self.norm_1 = LayerNorm(hidden_channels) - self.conv_2 = nn.Conv1d(hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2) self.norm_2 = LayerNorm(hidden_channels) # output layer self.proj = nn.Conv1d(hidden_channels, 1, 1) diff --git a/TTS/tts/layers/glow_tts/encoder.py b/TTS/tts/layers/glow_tts/encoder.py index 8de006a9..e7c1205f 100644 --- a/TTS/tts/layers/glow_tts/encoder.py +++ b/TTS/tts/layers/glow_tts/encoder.py @@ -69,17 +69,20 @@ class Encoder(nn.Module): 'num_layers': 9, } """ - def __init__(self, - num_chars, - out_channels, - hidden_channels, - hidden_channels_dp, - encoder_type, - encoder_params, - dropout_p_dp=0.1, - mean_only=False, - use_prenet=True, - c_in_channels=0): + + def __init__( + self, + num_chars, + out_channels, + hidden_channels, + hidden_channels_dp, + encoder_type, + encoder_params, + dropout_p_dp=0.1, + mean_only=False, + use_prenet=True, + c_in_channels=0, + ): super().__init__() # class arguments self.num_chars = num_chars @@ -93,47 +96,33 @@ class Encoder(nn.Module): self.encoder_type = encoder_type # embedding layer self.emb = nn.Embedding(num_chars, hidden_channels) - nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5) # init encoder module if encoder_type.lower() == "rel_pos_transformer": if use_prenet: - self.prenet = ResidualConv1dLayerNormBlock(hidden_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_layers=3, - dropout_p=0.5) - self.encoder = RelativePositionTransformer(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) - elif encoder_type.lower() == 'gated_conv': - self.encoder = GatedConvBlock(hidden_channels, **encoder_params) - elif encoder_type.lower() == 'residual_conv_bn': - if use_prenet: - self.prenet = nn.Sequential( - nn.Conv1d(hidden_channels, hidden_channels, 1), - nn.ReLU() + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 ) - self.encoder = ResidualConv1dBNBlock(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) - self.postnet = nn.Sequential( - nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), - nn.BatchNorm1d(self.hidden_channels)) - elif encoder_type.lower() == 'time_depth_separable': + self.encoder = RelativePositionTransformer( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) + elif encoder_type.lower() == "gated_conv": + self.encoder = GatedConvBlock(hidden_channels, **encoder_params) + elif encoder_type.lower() == "residual_conv_bn": if use_prenet: - self.prenet = ResidualConv1dLayerNormBlock(hidden_channels, - hidden_channels, - hidden_channels, - kernel_size=5, - num_layers=3, - dropout_p=0.5) - self.encoder = TimeDepthSeparableConvBlock(hidden_channels, - hidden_channels, - hidden_channels, - **encoder_params) + self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU()) + self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params) + self.postnet = nn.Sequential( + nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels) + ) + elif encoder_type.lower() == "time_depth_separable": + if use_prenet: + self.prenet = ResidualConv1dLayerNormBlock( + hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5 + ) + self.encoder = TimeDepthSeparableConvBlock( + hidden_channels, hidden_channels, hidden_channels, **encoder_params + ) else: raise ValueError(" [!] Unkown encoder type.") @@ -143,8 +132,8 @@ class Encoder(nn.Module): self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1) # duration predictor self.duration_predictor = DurationPredictor( - hidden_channels + c_in_channels, hidden_channels_dp, 3, - dropout_p_dp) + hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp + ) def forward(self, x, x_lengths, g=None): """ @@ -159,15 +148,14 @@ class Encoder(nn.Module): # [B, D, T] x = torch.transpose(x, 1, -1) # compute input sequence mask - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) # prenet - if hasattr(self, 'prenet') and self.use_prenet: + if hasattr(self, "prenet") and self.use_prenet: x = self.prenet(x, x_mask) # encoder x = self.encoder(x, x_mask) # postnet - if hasattr(self, 'postnet'): + if hasattr(self, "postnet"): x = self.postnet(x) * x_mask # set duration predictor input if g is not None: diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py index c8ad410d..d279ad77 100644 --- a/TTS/tts/layers/glow_tts/glow.py +++ b/TTS/tts/layers/glow_tts/glow.py @@ -7,8 +7,7 @@ from ..generic.normalization import LayerNorm class ResidualConv1dLayerNormBlock(nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, - num_layers, dropout_p): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p): """Conv1d with Layer Normalization and residual connection as in GlowTTS paper. https://arxiv.org/pdf/1811.00002.pdf @@ -38,10 +37,10 @@ class ResidualConv1dLayerNormBlock(nn.Module): for idx in range(num_layers): self.conv_layers.append( - nn.Conv1d(in_channels if idx == 0 else hidden_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2)) + nn.Conv1d( + in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) self.norm_layers.append(LayerNorm(hidden_channels)) self.proj = nn.Conv1d(hidden_channels, out_channels, 1) @@ -72,6 +71,7 @@ class InvConvNear(nn.Module): perform 1x1 convolution separately. Cast 1x1 conv operation to 2d by reshaping the input for efficiency. """ + def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument super().__init__() assert num_splits % 2 == 0 @@ -80,8 +80,7 @@ class InvConvNear(nn.Module): self.no_jacobian = no_jacobian self.weight_inv = None - w_init = torch.qr( - torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] + w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0] if torch.det(w_init) < 0: w_init[:, 0] = -1 * w_init[:, 0] self.weight = nn.Parameter(w_init) @@ -97,28 +96,25 @@ class InvConvNear(nn.Module): assert c % self.num_splits == 0 if x_mask is None: x_mask = 1 - x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t + x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t else: x_len = torch.sum(x_mask, [1, 2]) x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t) - x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, - c // self.num_splits, t) + x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t) if reverse: if self.weight_inv is not None: weight = self.weight_inv else: - weight = torch.inverse( - self.weight.float()).to(dtype=self.weight.dtype) + weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) logdet = None else: weight = self.weight if self.no_jacobian: logdet = 0 else: - logdet = torch.logdet( - self.weight) * (c / self.num_splits) * x_len # [b] + logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b] weight = weight.view(self.num_splits, self.num_splits, 1, 1) z = F.conv2d(x, weight) @@ -128,40 +124,42 @@ class InvConvNear(nn.Module): return z, logdet def store_inverse(self): - weight_inv = torch.inverse( - self.weight.float()).to(dtype=self.weight.dtype) + weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype) self.weight_inv = nn.Parameter(weight_inv, requires_grad=False) class CouplingBlock(nn.Module): """Glow Affine Coupling block as in GlowTTS paper. - https://arxiv.org/pdf/1811.00002.pdf + https://arxiv.org/pdf/1811.00002.pdf - x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o - '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ + x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o + '-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^ - Args: - in_channels (int): number of input tensor channels. - hidden_channels (int): number of hidden channels. - kernel_size (int): WaveNet filter kernel size. - dilation_rate (int): rate to increase dilation by each layer in a decoder block. - num_layers (int): number of WaveNet layers. - c_in_channels (int): number of conditioning input channels. - dropout_p (int): wavenet dropout rate. - sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. + Args: + in_channels (int): number of input tensor channels. + hidden_channels (int): number of hidden channels. + kernel_size (int): WaveNet filter kernel size. + dilation_rate (int): rate to increase dilation by each layer in a decoder block. + num_layers (int): number of WaveNet layers. + c_in_channels (int): number of conditioning input channels. + dropout_p (int): wavenet dropout rate. + sigmoid_scale (bool): enable/disable sigmoid scaling for output scale. - Note: - It does not use conditional inputs differently from WaveGlow. + Note: + It does not use conditional inputs differently from WaveGlow. """ - def __init__(self, - in_channels, - hidden_channels, - kernel_size, - dilation_rate, - num_layers, - c_in_channels=0, - dropout_p=0, - sigmoid_scale=False): + + def __init__( + self, + in_channels, + hidden_channels, + kernel_size, + dilation_rate, + num_layers, + c_in_channels=0, + dropout_p=0, + sigmoid_scale=False, + ): super().__init__() self.in_channels = in_channels self.hidden_channels = hidden_channels @@ -183,8 +181,7 @@ class CouplingBlock(nn.Module): end.bias.data.zero_() self.end = end # coupling layers - self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, - num_layers, c_in_channels, dropout_p) + self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p) def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument """ @@ -195,15 +192,15 @@ class CouplingBlock(nn.Module): """ if x_mask is None: x_mask = 1 - x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:] + x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :] x = self.start(x_0) * x_mask x = self.wn(x, x_mask, g) out = self.end(x) z_0 = x_0 - t = out[:, :self.in_channels // 2, :] - s = out[:, self.in_channels // 2:, :] + t = out[:, : self.in_channels // 2, :] + s = out[:, self.in_channels // 2 :, :] if self.sigmoid_scale: s = torch.log(1e-6 + torch.sigmoid(s + 2)) diff --git a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py index 78fa0fbf..9673e9a2 100644 --- a/TTS/tts/layers/glow_tts/monotonic_align/__init__.py +++ b/TTS/tts/layers/glow_tts/monotonic_align/__init__.py @@ -6,6 +6,7 @@ from TTS.tts.utils.generic_utils import sequence_mask try: # TODO: fix pypi cython installation problem. from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c + CYTHON = True except ModuleNotFoundError: CYTHON = False @@ -30,8 +31,7 @@ def generate_path(duration, mask): cum_duration_flat = cum_duration.view(b * t_x) path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) path = path.view(b, t_x, t_y) - path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0] - ]))[:, :-1] + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] path = path * mask return path @@ -43,7 +43,7 @@ def maximum_path(value, mask): def maximum_path_cython(value, mask): - """ Cython optimised version. + """Cython optimised version. value: [b, t_x, t_y] mask: [b, t_x, t_y] """ diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py index 77ea05f9..78b5b2f4 100644 --- a/TTS/tts/layers/glow_tts/transformer.py +++ b/TTS/tts/layers/glow_tts/transformer.py @@ -48,16 +48,19 @@ class RelativePositionMultiHeadAttention(nn.Module): proximal_init (bool, optional): enable/disable poximal init as in the paper. Init key and query layer weights the same. Defaults to False. """ - def __init__(self, - channels, - out_channels, - num_heads, - rel_attn_window_size=None, - heads_share=True, - dropout_p=0., - input_length=None, - proximal_bias=False, - proximal_init=False): + + def __init__( + self, + channels, + out_channels, + num_heads, + rel_attn_window_size=None, + heads_share=True, + dropout_p=0.0, + input_length=None, + proximal_bias=False, + proximal_init=False, + ): super().__init__() assert channels % num_heads == 0, " [!] channels should be divisible by num_heads." @@ -82,15 +85,15 @@ class RelativePositionMultiHeadAttention(nn.Module): # relative positional encoding layers if rel_attn_window_size is not None: n_heads_rel = 1 if heads_share else num_heads - rel_stddev = self.k_channels**-0.5 + rel_stddev = self.k_channels ** -0.5 emb_rel_k = nn.Parameter( - torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, - self.k_channels) * rel_stddev) + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) emb_rel_v = nn.Parameter( - torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, - self.k_channels) * rel_stddev) - self.register_parameter('emb_rel_k', emb_rel_k) - self.register_parameter('emb_rel_v', emb_rel_v) + torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev + ) + self.register_parameter("emb_rel_k", emb_rel_k) + self.register_parameter("emb_rel_v", emb_rel_v) # init layers nn.init.xavier_uniform_(self.conv_q.weight) @@ -112,38 +115,30 @@ class RelativePositionMultiHeadAttention(nn.Module): def attention(self, query, key, value, mask=None): # reshape [b, d, t] -> [b, n_h, t, d_k] b, d, t_s, t_t = (*key.size(), query.size(2)) - query = query.view(b, self.num_heads, self.k_channels, - t_t).transpose(2, 3) + query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3) key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) - value = value.view(b, self.num_heads, self.k_channels, - t_s).transpose(2, 3) + value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3) # compute raw attention scores - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt( - self.k_channels) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) # relative positional encoding for scores if self.rel_attn_window_size is not None: assert t_s == t_t, "Relative attention is only available for self-attention." # get relative key embeddings - key_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_k, t_s) - rel_logits = self._matmul_with_relative_keys( - query, key_relative_embeddings) - rel_logits = self._relative_position_to_absolute_position( - rel_logits) + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) scores_local = rel_logits / math.sqrt(self.k_channels) scores = scores + scores_local # proximan bias if self.proximal_bias: assert t_s == t_t, "Proximal bias is only available for self-attention." - scores = scores + self._attn_proximity_bias(t_s).to( - device=scores.device, dtype=scores.dtype) + scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype) # attention score masking if mask is not None: # add small value to prevent oor error. scores = scores.masked_fill(mask == 0, -1e4) if self.input_length is not None: - block_mask = torch.ones_like(scores).triu( - -1 * self.input_length).tril(self.input_length) + block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length) scores = scores * block_mask + -1e4 * (1 - block_mask) # attention score normalization p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] @@ -153,14 +148,10 @@ class RelativePositionMultiHeadAttention(nn.Module): output = torch.matmul(p_attn, value) # relative positional encoding for values if self.rel_attn_window_size is not None: - relative_weights = self._absolute_position_to_relative_position( - p_attn) - value_relative_embeddings = self._get_relative_embeddings( - self.emb_rel_v, t_s) - output = output + self._matmul_with_relative_values( - relative_weights, value_relative_embeddings) - output = output.transpose(2, 3).contiguous().view( - b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] return output, p_attn @staticmethod @@ -195,20 +186,16 @@ class RelativePositionMultiHeadAttention(nn.Module): return logits def _get_relative_embeddings(self, relative_embeddings, length): - """Convert embedding vestors to a tensor of embeddings - """ + """Convert embedding vestors to a tensor of embeddings""" # Pad first before slice to avoid using cond ops. pad_length = max(length - (self.rel_attn_window_size + 1), 0) slice_start_position = max((self.rel_attn_window_size + 1) - length, 0) slice_end_position = slice_start_position + 2 * length - 1 if pad_length > 0: - padded_relative_embeddings = F.pad( - relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) + padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0]) else: padded_relative_embeddings = relative_embeddings - used_relative_embeddings = padded_relative_embeddings[:, - slice_start_position: - slice_end_position] + used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position] return used_relative_embeddings @staticmethod @@ -226,8 +213,7 @@ class RelativePositionMultiHeadAttention(nn.Module): x_flat = x.view([batch, heads, length * 2 * length]) x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0]) # Reshape and slice out the padded elements. - x_final = x_flat.view([batch, heads, length + 1, - 2 * length - 1])[:, :, :length, length - 1:] + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :] return x_final @staticmethod @@ -239,7 +225,7 @@ class RelativePositionMultiHeadAttention(nn.Module): batch, heads, length, _ = x.size() # padd along column x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0]) - x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)]) # add 0's in the beginning that will skew the elements after reshape x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0]) x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] @@ -267,19 +253,15 @@ class RelativePositionMultiHeadAttention(nn.Module): class FeedForwardNetwork(nn.Module): """Feed Forward Inner layers for Transformer. - Args: - in_channels (int): input tensor channels. - out_channels (int): output tensor channels. - hidden_channels (int): inner layers hidden channels. - kernel_size (int): conv1d filter kernel size. - dropout_p (float, optional): dropout rate. Defaults to 0. + Args: + in_channels (int): input tensor channels. + out_channels (int): output tensor channels. + hidden_channels (int): inner layers hidden channels. + kernel_size (int): conv1d filter kernel size. + dropout_p (float, optional): dropout rate. Defaults to 0. """ - def __init__(self, - in_channels, - out_channels, - hidden_channels, - kernel_size, - dropout_p=0.): + + def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0): super().__init__() self.in_channels = in_channels @@ -288,14 +270,8 @@ class FeedForwardNetwork(nn.Module): self.kernel_size = kernel_size self.dropout_p = dropout_p - self.conv_1 = nn.Conv1d(in_channels, - hidden_channels, - kernel_size, - padding=kernel_size // 2) - self.conv_2 = nn.Conv1d(hidden_channels, - out_channels, - kernel_size, - padding=kernel_size // 2) + self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2) + self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2) self.dropout = nn.Dropout(dropout_p) def forward(self, x, x_mask): @@ -308,34 +284,37 @@ class FeedForwardNetwork(nn.Module): class RelativePositionTransformer(nn.Module): """Transformer with Relative Potional Encoding. - https://arxiv.org/abs/1803.02155 + https://arxiv.org/abs/1803.02155 - Args: - in_channels (int): number of channels of the input tensor. - out_chanels (int): number of channels of the output tensor. - hidden_channels (int): model hidden channels. - hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. - num_heads (int): number of attention heads. - num_layers (int): number of transformer layers. - kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. - dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. - rel_attn_window_size (int, optional): relation attention window size. - If 4, for each time step next and previous 4 time steps are attended. - If default, relative encoding is disabled and it is a regular transformer. - Defaults to None. - input_length (int, optional): input lenght to limit position encoding. Defaults to None. + Args: + in_channels (int): number of channels of the input tensor. + out_chanels (int): number of channels of the output tensor. + hidden_channels (int): model hidden channels. + hidden_channels_ffn (int): hidden channels of FeedForwardNetwork. + num_heads (int): number of attention heads. + num_layers (int): number of transformer layers. + kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1. + dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0. + rel_attn_window_size (int, optional): relation attention window size. + If 4, for each time step next and previous 4 time steps are attended. + If default, relative encoding is disabled and it is a regular transformer. + Defaults to None. + input_length (int, optional): input lenght to limit position encoding. Defaults to None. """ - def __init__(self, - in_channels, - out_channels, - hidden_channels, - hidden_channels_ffn, - num_heads, - num_layers, - kernel_size=1, - dropout_p=0., - rel_attn_window_size=None, - input_length=None): + + def __init__( + self, + in_channels, + out_channels, + hidden_channels, + hidden_channels_ffn, + num_heads, + num_layers, + kernel_size=1, + dropout_p=0.0, + rel_attn_window_size=None, + input_length=None, + ): super().__init__() self.hidden_channels = hidden_channels self.hidden_channels_ffn = hidden_channels_ffn @@ -359,7 +338,9 @@ class RelativePositionTransformer(nn.Module): num_heads, rel_attn_window_size=rel_attn_window_size, dropout_p=dropout_p, - input_length=input_length)) + input_length=input_length, + ) + ) self.norm_layers_1.append(LayerNorm(hidden_channels)) if hidden_channels != out_channels and (idx + 1) == self.num_layers: @@ -368,15 +349,14 @@ class RelativePositionTransformer(nn.Module): self.ffn_layers.append( FeedForwardNetwork( hidden_channels, - hidden_channels if - (idx + 1) != self.num_layers else out_channels, + hidden_channels if (idx + 1) != self.num_layers else out_channels, hidden_channels_ffn, kernel_size, - dropout_p=dropout_p)) + dropout_p=dropout_p, + ) + ) - self.norm_layers_2.append( - LayerNorm(hidden_channels if ( - idx + 1) != self.num_layers else out_channels)) + self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels)) def forward(self, x, x_mask): """ @@ -394,7 +374,7 @@ class RelativePositionTransformer(nn.Module): y = self.ffn_layers[i](x, x_mask) y = self.dropout(y) - if (i + 1) == self.num_layers and hasattr(self, 'proj'): + if (i + 1) == self.num_layers and hasattr(self, "proj"): x = self.proj(x) x = self.norm_layers_2[i](x + y) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 03baf488..00399514 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -34,26 +34,23 @@ class L1LossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.l1_loss(x * mask, - target * mask, - reduction='none') + loss = functional.l1_loss(x * mask, target * mask, reduction="none") loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.l1_loss(x * mask, target * mask, reduction='sum') + loss = functional.l1_loss(x * mask, target * mask, reduction="sum") loss = loss / mask.sum() return loss class MSELossMasked(nn.Module): def __init__(self, seq_len_norm): - super(MSELossMasked, self).__init__() + super().__init__() self.seq_len_norm = seq_len_norm def forward(self, x, target, length): @@ -76,27 +73,23 @@ class MSELossMasked(nn.Module): """ # mask: (batch, max_len, 1) target.requires_grad = False - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).unsqueeze(2).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float() if self.seq_len_norm: norm_w = mask / mask.sum(dim=1, keepdim=True) out_weights = norm_w.div(target.shape[0] * target.shape[2]) mask = mask.expand_as(x) - loss = functional.mse_loss(x * mask, - target * mask, - reduction='none') + loss = functional.mse_loss(x * mask, target * mask, reduction="none") loss = loss.mul(out_weights.to(loss.device)).sum() else: mask = mask.expand_as(x) - loss = functional.mse_loss(x * mask, - target * mask, - reduction='sum') + loss = functional.mse_loss(x * mask, target * mask, reduction="sum") loss = loss / mask.sum() return loss class SSIMLoss(torch.nn.Module): """SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity""" + def __init__(self): super().__init__() self.loss_func = ssim @@ -115,9 +108,7 @@ class SSIMLoss(torch.nn.Module): loss: An average loss value in range [0, 1] masked by the length. """ if length is not None: - m = sequence_mask(sequence_length=length, - max_len=y.size(1)).unsqueeze(2).float().to( - y_hat.device) + m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device) y_hat, y = y_hat * m, y * m return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1)) @@ -139,7 +130,7 @@ class AttentionEntropyLoss(nn.Module): class BCELossMasked(nn.Module): def __init__(self, pos_weight): - super(BCELossMasked, self).__init__() + super().__init__() self.pos_weight = pos_weight def forward(self, x, target, length): @@ -163,25 +154,20 @@ class BCELossMasked(nn.Module): # mask: (batch, max_len, 1) target.requires_grad = False if length is not None: - mask = sequence_mask(sequence_length=length, - max_len=target.size(1)).float() + mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float() x = x * mask target = target * mask num_items = mask.sum() else: num_items = torch.numel(x) - loss = functional.binary_cross_entropy_with_logits( - x, - target, - pos_weight=self.pos_weight, - reduction='sum') + loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum") loss = loss / num_items return loss class DifferentailSpectralLoss(nn.Module): """Differential Spectral Loss - https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" + https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf""" def __init__(self, loss_func): super().__init__() @@ -200,12 +186,12 @@ class DifferentailSpectralLoss(nn.Module): target_diff = target[:, 1:] - target[:, :-1] if length is None: return self.loss_func(x_diff, target_diff) - return self.loss_func(x_diff, target_diff, length-1) + return self.loss_func(x_diff, target_diff, length - 1) class GuidedAttentionLoss(torch.nn.Module): def __init__(self, sigma=0.4): - super(GuidedAttentionLoss, self).__init__() + super().__init__() self.sigma = sigma def _make_ga_masks(self, ilens, olens): @@ -214,8 +200,7 @@ class GuidedAttentionLoss(torch.nn.Module): max_olen = max(olens) ga_masks = torch.zeros((B, max_olen, max_ilen)) for idx, (ilen, olen) in enumerate(zip(ilens, olens)): - ga_masks[idx, :olen, :ilen] = self._make_ga_mask( - ilen, olen, self.sigma) + ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma) return ga_masks def forward(self, att_ws, ilens, olens): @@ -229,8 +214,7 @@ class GuidedAttentionLoss(torch.nn.Module): def _make_ga_mask(ilen, olen, sigma): grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen)) grid_x, grid_y = grid_x.float(), grid_y.float() - return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 / - (2 * (sigma**2))) + return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2))) @staticmethod def _make_masks(ilens, olens): @@ -249,18 +233,19 @@ class Huber(nn.Module): length: B """ mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float() - return torch.nn.functional.smooth_l1_loss( - x * mask, y * mask, reduction='sum') / mask.sum() + return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum() ######################## # MODEL LOSS LAYERS ######################## + class TacotronLoss(torch.nn.Module): """Collection of Tacotron set-up based on provided config.""" + def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): - super(TacotronLoss, self).__init__() + super().__init__() self.stopnet_pos_weight = stopnet_pos_weight self.ga_alpha = c.ga_alpha self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha @@ -273,12 +258,9 @@ class TacotronLoss(torch.nn.Module): # postnet and decoder loss if c.loss_masking: - self.criterion = L1LossMasked(c.seq_len_norm) if c.model in [ - "Tacotron" - ] else MSELossMasked(c.seq_len_norm) + self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm) else: - self.criterion = nn.L1Loss() if c.model in ["Tacotron" - ] else nn.MSELoss() + self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss() # guided attention loss if c.ga_alpha > 0: self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma) @@ -290,13 +272,23 @@ class TacotronLoss(torch.nn.Module): self.criterion_ssim = SSIMLoss() # stopnet loss # pylint: disable=not-callable - self.criterion_st = BCELossMasked( - pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None - - def forward(self, postnet_output, decoder_output, mel_input, linear_input, - stopnet_output, stopnet_target, output_lens, decoder_b_output, - alignments, alignment_lens, alignments_backwards, input_lens): + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None + def forward( + self, + postnet_output, + decoder_output, + mel_input, + linear_input, + stopnet_output, + stopnet_target, + output_lens, + decoder_b_output, + alignments, + alignment_lens, + alignments_backwards, + input_lens, + ): # decoder outputs linear or mel spectrograms for Tacotron and Tacotron2 # the target should be set acccordingly @@ -309,85 +301,80 @@ class TacotronLoss(torch.nn.Module): # decoder and postnet losses if self.config.loss_masking: if self.decoder_alpha > 0: - decoder_loss = self.criterion(decoder_output, mel_input, - output_lens) + decoder_loss = self.criterion(decoder_output, mel_input, output_lens) if self.postnet_alpha > 0: - postnet_loss = self.criterion(postnet_output, postnet_target, - output_lens) + postnet_loss = self.criterion(postnet_output, postnet_target, output_lens) else: if self.decoder_alpha > 0: decoder_loss = self.criterion(decoder_output, mel_input) if self.postnet_alpha > 0: postnet_loss = self.criterion(postnet_output, postnet_target) loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss - return_dict['decoder_loss'] = decoder_loss - return_dict['postnet_loss'] = postnet_loss + return_dict["decoder_loss"] = decoder_loss + return_dict["postnet_loss"] = postnet_loss # stopnet loss - stop_loss = self.criterion_st( - stopnet_output, stopnet_target, - output_lens) if self.config.stopnet else torch.zeros(1) + stop_loss = ( + self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1) + ) if not self.config.separate_stopnet and self.config.stopnet: loss += stop_loss - return_dict['stopnet_loss'] = stop_loss + return_dict["stopnet_loss"] = stop_loss # backward decoder loss (if enabled) if self.config.bidirectional_decoder: if self.config.loss_masking: - decoder_b_loss = self.criterion( - torch.flip(decoder_b_output, dims=(1, )), mel_input, - output_lens) + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens) else: - decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input) - decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output) + decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input) + decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output) loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss) - return_dict['decoder_b_loss'] = decoder_b_loss - return_dict['decoder_c_loss'] = decoder_c_loss + return_dict["decoder_b_loss"] = decoder_b_loss + return_dict["decoder_c_loss"] = decoder_c_loss # double decoder consistency loss (if enabled) if self.config.double_decoder_consistency: if self.config.loss_masking: - decoder_b_loss = self.criterion(decoder_b_output, mel_input, - output_lens) + decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens) else: decoder_b_loss = self.criterion(decoder_b_output, mel_input) # decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output) attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards) loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss) - return_dict['decoder_coarse_loss'] = decoder_b_loss - return_dict['decoder_ddc_loss'] = attention_c_loss + return_dict["decoder_coarse_loss"] = decoder_b_loss + return_dict["decoder_ddc_loss"] = attention_c_loss # guided attention loss (if enabled) if self.config.ga_alpha > 0: ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens) loss += ga_loss * self.ga_alpha - return_dict['ga_loss'] = ga_loss + return_dict["ga_loss"] = ga_loss # decoder differential spectral loss if self.config.decoder_diff_spec_alpha > 0: decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens) loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha - return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss + return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss # postnet differential spectral loss if self.config.postnet_diff_spec_alpha > 0: postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens) loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha - return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss + return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss # decoder ssim loss if self.config.decoder_ssim_alpha > 0: decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens) loss += decoder_ssim_loss * self.postnet_ssim_alpha - return_dict['decoder_ssim_loss'] = decoder_ssim_loss + return_dict["decoder_ssim_loss"] = decoder_ssim_loss # postnet ssim loss if self.config.postnet_ssim_alpha > 0: postnet_ssim_loss = self.criterion_ssim(postnet_output, postnet_target, output_lens) loss += postnet_ssim_loss * self.postnet_ssim_alpha - return_dict['postnet_ssim_loss'] = postnet_ssim_loss + return_dict["postnet_ssim_loss"] = postnet_ssim_loss - return_dict['loss'] = loss + return_dict["loss"] = loss # check if any loss is NaN for key, loss in return_dict.items(): @@ -401,22 +388,18 @@ class GlowTTSLoss(torch.nn.Module): super().__init__() self.constant_factor = 0.5 * math.log(2 * math.pi) - def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, - o_attn_dur, x_lengths): + def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths): return_dict = {} # flow loss - neg log likelihood - pz = torch.sum(scales) + 0.5 * torch.sum( - torch.exp(-2 * scales) * (z - means)**2) - log_mle = self.constant_factor + (pz - torch.sum(log_det)) / ( - torch.sum(y_lengths) * z.shape[1]) + pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2) + log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1]) # duration loss - MSE # loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths) # duration loss - huber loss - loss_dur = torch.nn.functional.smooth_l1_loss( - o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths) - return_dict['loss'] = log_mle + loss_dur - return_dict['log_mle'] = log_mle - return_dict['loss_dur'] = loss_dur + loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths) + return_dict["loss"] = log_mle + loss_dur + return_dict["log_mle"] = log_mle + return_dict["loss_dur"] = loss_dur # check if any loss is NaN for key, loss in return_dict.items(): @@ -441,7 +424,7 @@ class SpeedySpeechLoss(nn.Module): ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) huber_loss = self.huber(dur_output, dur_target, input_lens) loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss - return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss} + return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss} def mse_loss_custom(x, y): @@ -452,26 +435,27 @@ def mse_loss_custom(x, y): class MDNLoss(nn.Module): - """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf. - """ + """Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.""" def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use - ''' + """ Shapes: mu: [B, D, T] log_sigma: [B, D, T] mel_spec: [B, D, T] - ''' + """ B, T_seq, T_mel = logp.shape - log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4) + log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4) log_alpha[:, 0, 0] = logp[:, 0, 0] for t in range(1, T_mel): - prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t], - (0, 0, 1, -1), value=-1e4)], dim=-1) + prev_step = torch.cat( + [log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)], + dim=-1, + ) log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t] - alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1] + alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1] mdn_loss = -alpha_last.mean() / T_seq - return mdn_loss#, log_prob_matrix + return mdn_loss # , log_prob_matrix class AlignTTSLoss(nn.Module): @@ -487,6 +471,7 @@ class AlignTTSLoss(nn.Module): Args: c (dict): TTS model configuration. """ + def __init__(self, c): super().__init__() self.mdn_loss = MDNLoss() @@ -499,10 +484,10 @@ class AlignTTSLoss(nn.Module): self.spec_loss_alpha = c.spec_loss_alpha self.mdn_alpha = c.mdn_alpha - def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, - input_lens, step, phase): - ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas( - step) + def forward( + self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase + ): + ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step) spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0 if phase == 0: mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens) @@ -521,11 +506,11 @@ class AlignTTSLoss(nn.Module): ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens) dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens) loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss - return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss} + return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss} @staticmethod def _set_alpha(step, alpha_settings): - '''Set the loss alpha wrt number of steps. + """Set the loss alpha wrt number of steps. Return the corresponding value if no schedule is set. Example: @@ -536,7 +521,7 @@ class AlignTTSLoss(nn.Module): Args: step (int): number of training steps. alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above. - ''' + """ return_alpha = None if isinstance(alpha_settings, list): for key, alpha in alpha_settings: @@ -547,8 +532,7 @@ class AlignTTSLoss(nn.Module): return return_alpha def set_alphas(self, step): - '''Set the alpha values for all the loss functions - ''' + """Set the alpha values for all the loss functions""" ssim_alpha = self._set_alpha(step, self.ssim_alpha) dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha) spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha) diff --git a/TTS/tts/layers/tacotron/attentions.py b/TTS/tts/layers/tacotron/attentions.py index 1f682e4c..320a8509 100644 --- a/TTS/tts/layers/tacotron/attentions.py +++ b/TTS/tts/layers/tacotron/attentions.py @@ -14,20 +14,18 @@ class LocationLayer(nn.Module): attention_n_filters (int, optional): number of filters in convolution. Defaults to 32. attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31. """ - def __init__(self, - attention_dim, - attention_n_filters=32, - attention_kernel_size=31): - super(LocationLayer, self).__init__() + + def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31): + super().__init__() self.location_conv1d = nn.Conv1d( in_channels=2, out_channels=attention_n_filters, kernel_size=attention_kernel_size, stride=1, padding=(attention_kernel_size - 1) // 2, - bias=False) - self.location_dense = Linear( - attention_n_filters, attention_dim, bias=False, init_gain='tanh') + bias=False, + ) + self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh") def forward(self, attention_cat): """ @@ -35,8 +33,7 @@ class LocationLayer(nn.Module): attention_cat: [B, 2, C] """ processed_attention = self.location_conv1d(attention_cat) - processed_attention = self.location_dense( - processed_attention.transpose(1, 2)) + processed_attention = self.location_dense(processed_attention.transpose(1, 2)) return processed_attention @@ -49,31 +46,31 @@ class GravesAttention(nn.Module): query_dim (int): number of channels in query tensor. K (int): number of Gaussian heads to be used for computing attention. """ + COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi)) def __init__(self, query_dim, K): - super(GravesAttention, self).__init__() + super().__init__() self._mask_value = 1e-8 self.K = K # self.attention_alignment = 0.05 self.eps = 1e-5 self.J = None self.N_a = nn.Sequential( - nn.Linear(query_dim, query_dim, bias=True), - nn.ReLU(), - nn.Linear(query_dim, 3*K, bias=True)) + nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True) + ) self.attention_weights = None self.mu_prev = None self.init_layers() def init_layers(self): - torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) # bias mean - torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) # bias std + torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean + torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std def init_states(self, inputs): - if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]: - self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5 + if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]: + self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5 self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device) self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device) @@ -108,7 +105,7 @@ class GravesAttention(nn.Module): mu_t = self.mu_prev + torch.nn.functional.softplus(k_t) g_t = torch.softmax(g_t, dim=-1) + self.eps - j = self.J[:inputs.size(1)+1] + j = self.J[: inputs.size(1) + 1] # attention weights phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1)))) @@ -164,21 +161,29 @@ class OriginalAttention(nn.Module): trans_agent (bool): enable/disable transition agent in the forward attention. forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is useful to set at especially inference time. """ + # Pylint gets confused by PyTorch conventions here - #pylint: disable=attribute-defined-outside-init - def __init__(self, query_dim, embedding_dim, attention_dim, - location_attention, attention_location_n_filters, - attention_location_kernel_size, windowing, norm, forward_attn, - trans_agent, forward_attn_mask): - super(OriginalAttention, self).__init__() - self.query_layer = Linear( - query_dim, attention_dim, bias=False, init_gain='tanh') - self.inputs_layer = Linear( - embedding_dim, attention_dim, bias=False, init_gain='tanh') + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ): + super().__init__() + self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh") + self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh") self.v = Linear(attention_dim, 1, bias=True) if trans_agent: - self.ta = nn.Linear( - query_dim + embedding_dim, 1, bias=True) + self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True) if location_attention: self.location_layer = LocationLayer( attention_dim, @@ -202,9 +207,7 @@ class OriginalAttention(nn.Module): def init_forward_attn(self, inputs): B = inputs.shape[0] T = inputs.shape[1] - self.alpha = torch.cat( - [torch.ones([B, 1]), - torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) + self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device) self.u = (0.5 * torch.ones([B, 1])).to(inputs.device) def init_location_attention(self, inputs): @@ -230,14 +233,10 @@ class OriginalAttention(nn.Module): self.attention_weights_cum += alignments def get_location_attention(self, query, processed_inputs): - attention_cat = torch.cat((self.attention_weights.unsqueeze(1), - self.attention_weights_cum.unsqueeze(1)), - dim=1) + attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1) processed_query = self.query_layer(query.unsqueeze(1)) processed_attention_weights = self.location_layer(attention_cat) - energies = self.v( - torch.tanh(processed_query + processed_attention_weights + - processed_inputs)) + energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs)) energies = energies.squeeze(-1) return energies, processed_query @@ -264,24 +263,17 @@ class OriginalAttention(nn.Module): def apply_forward_attention(self, alignment): # forward attention - fwd_shifted_alpha = F.pad( - self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) + fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0)) # compute transition potentials - alpha = ((1 - self.u) * self.alpha - + self.u * fwd_shifted_alpha - + 1e-8) * alignment + alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment # force incremental alignment if not self.training and self.forward_attn_mask: _, n = fwd_shifted_alpha.max(1) val, _ = alpha.max(1) for b in range(alignment.shape[0]): - alpha[b, n[b] + 3:] = 0 - alpha[b, :( - n[b] - 1 - )] = 0 # ignore all previous states to prevent repetition. - alpha[b, - (n[b] - 2 - )] = 0.01 * val[b] # smoothing factor for the prev step + alpha[b, n[b] + 3 :] = 0 + alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition. + alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step # renormalize attention weights alpha = alpha / alpha.sum(dim=1, keepdim=True) return alpha @@ -295,11 +287,9 @@ class OriginalAttention(nn.Module): mask: [B, T_en] """ if self.location_attention: - attention, _ = self.get_location_attention( - query, processed_inputs) + attention, _ = self.get_location_attention(query, processed_inputs) else: - attention, _ = self.get_attention( - query, processed_inputs) + attention, _ = self.get_attention(query, processed_inputs) # apply masking if mask is not None: attention.data.masked_fill_(~mask, self._mask_value) @@ -311,9 +301,7 @@ class OriginalAttention(nn.Module): if self.norm == "softmax": alignment = torch.softmax(attention, dim=-1) elif self.norm == "sigmoid": - alignment = torch.sigmoid(attention) / torch.sigmoid( - attention).sum( - dim=1, keepdim=True) + alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True) else: raise ValueError("Unknown value for attention norm type") @@ -367,19 +355,20 @@ class MonotonicDynamicConvolutionAttention(nn.Module): alpha (float, optional): [description]. Defaults to 0.1 from the paper. beta (float, optional): [description]. Defaults to 0.9 from the paper. """ + def __init__( - self, - query_dim, - embedding_dim, # pylint: disable=unused-argument - attention_dim, - static_filter_dim, - static_kernel_size, - dynamic_filter_dim, - dynamic_kernel_size, - prior_filter_len=11, - alpha=0.1, - beta=0.9, - ): + self, + query_dim, + embedding_dim, # pylint: disable=unused-argument + attention_dim, + static_filter_dim, + static_kernel_size, + dynamic_filter_dim, + dynamic_kernel_size, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ): super().__init__() self._mask_value = 1e-8 self.dynamic_filter_dim = dynamic_filter_dim @@ -388,9 +377,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module): self.attention_weights = None # setup key and query layers self.query_layer = nn.Linear(query_dim, attention_dim) - self.key_layer = nn.Linear( - attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False - ) + self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False) self.static_filter_conv = nn.Conv1d( 1, static_filter_dim, @@ -402,8 +389,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module): self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim) self.v = nn.Linear(attention_dim, 1, bias=False) - prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, - alpha, beta) + prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta) self.register_buffer("prior", torch.FloatTensor(prior).flip(0)) # pylint: disable=unused-argument @@ -416,8 +402,8 @@ class MonotonicDynamicConvolutionAttention(nn.Module): """ # compute prior filters prior_filter = F.conv1d( - F.pad(self.attention_weights.unsqueeze(1), - (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1)) + F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1) + ) prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1) G = self.key_layer(torch.tanh(self.query_layer(query))) # compute dynamic filters @@ -430,10 +416,12 @@ class MonotonicDynamicConvolutionAttention(nn.Module): dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2) # compute static filters static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2) - alignment = self.v( - torch.tanh( - self.static_filter_layer(static_filter) + - self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter + alignment = ( + self.v( + torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter)) + ).squeeze(-1) + + prior_filter + ) # compute attention weights attention_weights = F.softmax(alignment, dim=-1) # apply masking @@ -451,33 +439,52 @@ class MonotonicDynamicConvolutionAttention(nn.Module): B = inputs.size(0) T = inputs.size(1) self.attention_weights = torch.zeros([B, T], device=inputs.device) - self.attention_weights[:, 0] = 1. + self.attention_weights[:, 0] = 1.0 -def init_attn(attn_type, query_dim, embedding_dim, attention_dim, - location_attention, attention_location_n_filters, - attention_location_kernel_size, windowing, norm, forward_attn, - trans_agent, forward_attn_mask, attn_K): +def init_attn( + attn_type, + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + attn_K, +): if attn_type == "original": - return OriginalAttention(query_dim, embedding_dim, attention_dim, - location_attention, - attention_location_n_filters, - attention_location_kernel_size, windowing, - norm, forward_attn, trans_agent, - forward_attn_mask) + return OriginalAttention( + query_dim, + embedding_dim, + attention_dim, + location_attention, + attention_location_n_filters, + attention_location_kernel_size, + windowing, + norm, + forward_attn, + trans_agent, + forward_attn_mask, + ) if attn_type == "graves": return GravesAttention(query_dim, attn_K) if attn_type == "dynamic_convolution": - return MonotonicDynamicConvolutionAttention(query_dim, - embedding_dim, - attention_dim, - static_filter_dim=8, - static_kernel_size=21, - dynamic_filter_dim=8, - dynamic_kernel_size=21, - prior_filter_len=11, - alpha=0.1, - beta=0.9) + return MonotonicDynamicConvolutionAttention( + query_dim, + embedding_dim, + attention_dim, + static_filter_dim=8, + static_kernel_size=21, + dynamic_filter_dim=8, + dynamic_kernel_size=21, + prior_filter_len=11, + alpha=0.1, + beta=0.9, + ) - raise RuntimeError( - " [!] Given Attention Type '{attn_type}' is not exist.") + raise RuntimeError(" [!] Given Attention Type '{attn_type}' is not exist.") diff --git a/TTS/tts/layers/tacotron/common_layers.py b/TTS/tts/layers/tacotron/common_layers.py index a23bb3f9..d3a9b80d 100644 --- a/TTS/tts/layers/tacotron/common_layers.py +++ b/TTS/tts/layers/tacotron/common_layers.py @@ -12,20 +12,14 @@ class Linear(nn.Module): bias (bool, optional): enable/disable bias in the layer. Defaults to True. init_gain (str, optional): method to compute the gain in the weight initializtion based on the nonlinear activation used afterwards. Defaults to 'linear'. """ - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): - super(Linear, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) self._init_w(init_gain) def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) def forward(self, x): return self.linear_layer(x) @@ -42,21 +36,15 @@ class LinearBN(nn.Module): bias (bool, optional): enable/disable bias in the linear layer. Defaults to True. init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'. """ - def __init__(self, - in_features, - out_features, - bias=True, - init_gain='linear'): - super(LinearBN, self).__init__() - self.linear_layer = torch.nn.Linear( - in_features, out_features, bias=bias) + + def __init__(self, in_features, out_features, bias=True, init_gain="linear"): + super().__init__() + self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias) self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5) self._init_w(init_gain) def _init_w(self, init_gain): - torch.nn.init.xavier_uniform_( - self.linear_layer.weight, - gain=torch.nn.init.calculate_gain(init_gain)) + torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain)) def forward(self, x): """ @@ -96,27 +84,21 @@ class Prenet(nn.Module): Defaults to [256, 256]. bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True. """ + # pylint: disable=dangerous-default-value - def __init__(self, - in_features, - prenet_type="original", - prenet_dropout=True, - out_features=[256, 256], - bias=True): - super(Prenet, self).__init__() + def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True): + super().__init__() self.prenet_type = prenet_type self.prenet_dropout = prenet_dropout in_features = [in_features] + out_features[:-1] if prenet_type == "bn": - self.linear_layers = nn.ModuleList([ - LinearBN(in_size, out_size, bias=bias) - for (in_size, out_size) in zip(in_features, out_features) - ]) + self.linear_layers = nn.ModuleList( + [LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) elif prenet_type == "original": - self.linear_layers = nn.ModuleList([ - Linear(in_size, out_size, bias=bias) - for (in_size, out_size) in zip(in_features, out_features) - ]) + self.linear_layers = nn.ModuleList( + [Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)] + ) def forward(self, x): for linear in self.linear_layers: diff --git a/TTS/tts/layers/tacotron/gst_layers.py b/TTS/tts/layers/tacotron/gst_layers.py index 63e76070..e2784e5d 100644 --- a/TTS/tts/layers/tacotron/gst_layers.py +++ b/TTS/tts/layers/tacotron/gst_layers.py @@ -11,8 +11,7 @@ class GST(nn.Module): def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim=None): super().__init__() self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim) - self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, - gst_embedding_dim, speaker_embedding_dim) + self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim) def forward(self, inputs, speaker_embedding=None): enc_out = self.encoder(inputs) @@ -39,24 +38,17 @@ class ReferenceEncoder(nn.Module): num_layers = len(filters) - 1 convs = [ nn.Conv2d( - in_channels=filters[i], - out_channels=filters[i + 1], - kernel_size=(3, 3), - stride=(2, 2), - padding=(1, 1)) for i in range(num_layers) + in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ) + for i in range(num_layers) ] self.convs = nn.ModuleList(convs) - self.bns = nn.ModuleList([ - nn.BatchNorm2d(num_features=filter_size) - for filter_size in filters[1:] - ]) + self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]]) - post_conv_height = self.calculate_post_conv_height( - num_mel, 3, 2, 1, num_layers) + post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers) self.recurrence = nn.GRU( - input_size=filters[-1] * post_conv_height, - hidden_size=embedding_dim // 2, - batch_first=True) + input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True + ) def forward(self, inputs): batch_size = inputs.size(0) @@ -81,8 +73,7 @@ class ReferenceEncoder(nn.Module): return out.squeeze(0) @staticmethod - def calculate_post_conv_height(height, kernel_size, stride, pad, - n_convs): + def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs): """Height of spec after n convolutions with fixed kernel/stride/pad.""" for _ in range(n_convs): height = (height - kernel_size + 2 * pad) // stride + 1 @@ -92,8 +83,7 @@ class ReferenceEncoder(nn.Module): class StyleTokenLayer(nn.Module): """NN Module attending to style tokens based on prosody encodings.""" - def __init__(self, num_heads, num_style_tokens, - embedding_dim, speaker_embedding_dim=None): + def __init__(self, num_heads, num_style_tokens, embedding_dim, speaker_embedding_dim=None): super().__init__() self.query_dim = embedding_dim // 2 @@ -102,35 +92,31 @@ class StyleTokenLayer(nn.Module): self.query_dim += speaker_embedding_dim self.key_dim = embedding_dim // num_heads - self.style_tokens = nn.Parameter( - torch.FloatTensor(num_style_tokens, self.key_dim)) + self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim)) nn.init.normal_(self.style_tokens, mean=0, std=0.5) self.attention = MultiHeadAttention( - query_dim=self.query_dim, - key_dim=self.key_dim, - num_units=embedding_dim, - num_heads=num_heads) + query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads + ) def forward(self, inputs): batch_size = inputs.size(0) prosody_encoding = inputs.unsqueeze(1) # prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128] - tokens = torch.tanh(self.style_tokens) \ - .unsqueeze(0) \ - .expand(batch_size, -1, -1) + tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1) # tokens: 3D tensor [batch_size, num tokens, token embedding size] style_embed = self.attention(prosody_encoding, tokens) return style_embed + class MultiHeadAttention(nn.Module): - ''' + """ input: query --- [N, T_q, query_dim] key --- [N, T_k, key_dim] output: out --- [N, T_q, num_units] - ''' + """ def __init__(self, query_dim, key_dim, num_units, num_heads): @@ -139,12 +125,9 @@ class MultiHeadAttention(nn.Module): self.num_heads = num_heads self.key_dim = key_dim - self.W_query = nn.Linear( - in_features=query_dim, out_features=num_units, bias=False) - self.W_key = nn.Linear( - in_features=key_dim, out_features=num_units, bias=False) - self.W_value = nn.Linear( - in_features=key_dim, out_features=num_units, bias=False) + self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False) + self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) + self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False) def forward(self, query, key): queries = self.W_query(query) # [N, T_q, num_units] @@ -152,25 +135,17 @@ class MultiHeadAttention(nn.Module): values = self.W_value(key) split_size = self.num_units // self.num_heads - queries = torch.stack( - torch.split(queries, split_size, dim=2), - dim=0) # [h, N, T_q, num_units/h] - keys = torch.stack( - torch.split(keys, split_size, dim=2), - dim=0) # [h, N, T_k, num_units/h] - values = torch.stack( - torch.split(values, split_size, dim=2), - dim=0) # [h, N, T_k, num_units/h] + queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h] + keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] + values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h] # score = softmax(QK^T / (d_k ** 0.5)) scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k] - scores = scores / (self.key_dim**0.5) + scores = scores / (self.key_dim ** 0.5) scores = F.softmax(scores, dim=3) # out = score * V out = torch.matmul(scores, values) # [h, N, T_q, num_units/h] - out = torch.cat( - torch.split(out, 1, dim=0), - dim=3).squeeze(0) # [N, T_q, num_units] + out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units] return out diff --git a/TTS/tts/layers/tacotron/tacotron.py b/TTS/tts/layers/tacotron/tacotron.py index c79edcc3..5ff9ed1d 100644 --- a/TTS/tts/layers/tacotron/tacotron.py +++ b/TTS/tts/layers/tacotron/tacotron.py @@ -23,24 +23,14 @@ class BatchNormConv1d(nn.Module): - output: (B, D) """ - def __init__(self, - in_channels, - out_channels, - kernel_size, - stride, - padding, - activation=None): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None): - super(BatchNormConv1d, self).__init__() + super().__init__() self.padding = padding self.padder = nn.ConstantPad1d(padding, 0) self.conv1d = nn.Conv1d( - in_channels, - out_channels, - kernel_size=kernel_size, - stride=stride, - padding=0, - bias=False) + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False + ) # Following tensorflow's default parameters self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3) self.activation = activation @@ -48,15 +38,14 @@ class BatchNormConv1d(nn.Module): def init_layers(self): if isinstance(self.activation, torch.nn.ReLU): - w_gain = 'relu' + w_gain = "relu" elif isinstance(self.activation, torch.nn.Tanh): - w_gain = 'tanh' + w_gain = "tanh" elif self.activation is None: - w_gain = 'linear' + w_gain = "linear" else: - raise RuntimeError('Unknown activation function') - torch.nn.init.xavier_uniform_( - self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) + raise RuntimeError("Unknown activation function") + torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain)) def forward(self, x): x = self.padder(x) @@ -81,7 +70,7 @@ class Highway(nn.Module): # TODO: Try GLU layer def __init__(self, in_features, out_feature): - super(Highway, self).__init__() + super().__init__() self.H = nn.Linear(in_features, out_feature) self.H.bias.data.zero_() self.T = nn.Linear(in_features, out_feature) @@ -91,10 +80,8 @@ class Highway(nn.Module): # self.init_layers() def init_layers(self): - torch.nn.init.xavier_uniform_( - self.H.weight, gain=torch.nn.init.calculate_gain('relu')) - torch.nn.init.xavier_uniform_( - self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid')) + torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu")) + torch.nn.init.xavier_uniform_(self.T.weight, gain=torch.nn.init.calculate_gain("sigmoid")) def forward(self, inputs): H = self.relu(self.H(inputs)) @@ -104,30 +91,33 @@ class Highway(nn.Module): class CBHG(nn.Module): """CBHG module: a recurrent neural network composed of: - - 1-d convolution banks - - Highway networks + residual connections - - Bidirectional gated recurrent units + - 1-d convolution banks + - Highway networks + residual connections + - Bidirectional gated recurrent units - Args: - in_features (int): sample size - K (int): max filter size in conv bank - projections (list): conv channel sizes for conv projections - num_highways (int): number of highways layers + Args: + in_features (int): sample size + K (int): max filter size in conv bank + projections (list): conv channel sizes for conv projections + num_highways (int): number of highways layers - Shapes: - - input: (B, C, T_in) - - output: (B, T_in, C*2) + Shapes: + - input: (B, C, T_in) + - output: (B, T_in, C*2) """ - #pylint: disable=dangerous-default-value - def __init__(self, - in_features, - K=16, - conv_bank_features=128, - conv_projections=[128, 128], - highway_features=128, - gru_features=128, - num_highways=4): - super(CBHG, self).__init__() + + # pylint: disable=dangerous-default-value + def __init__( + self, + in_features, + K=16, + conv_bank_features=128, + conv_projections=[128, 128], + highway_features=128, + gru_features=128, + num_highways=4, + ): + super().__init__() self.in_features = in_features self.conv_bank_features = conv_bank_features self.highway_features = highway_features @@ -136,14 +126,19 @@ class CBHG(nn.Module): self.relu = nn.ReLU() # list of conv1d bank with filter size k=1...K # TODO: try dilational layers instead - self.conv1d_banks = nn.ModuleList([ - BatchNormConv1d(in_features, - conv_bank_features, - kernel_size=k, - stride=1, - padding=[(k - 1) // 2, k // 2], - activation=self.relu) for k in range(1, K + 1) - ]) + self.conv1d_banks = nn.ModuleList( + [ + BatchNormConv1d( + in_features, + conv_bank_features, + kernel_size=k, + stride=1, + padding=[(k - 1) // 2, k // 2], + activation=self.relu, + ) + for k in range(1, K + 1) + ] + ) # max pooling of conv bank, with padding # TODO: try average pooling OR larger kernel size out_features = [K * conv_bank_features] + conv_projections[:-1] @@ -151,31 +146,16 @@ class CBHG(nn.Module): activations += [None] # setup conv1d projection layers layer_set = [] - for (in_size, out_size, ac) in zip(out_features, conv_projections, - activations): - layer = BatchNormConv1d(in_size, - out_size, - kernel_size=3, - stride=1, - padding=[1, 1], - activation=ac) + for (in_size, out_size, ac) in zip(out_features, conv_projections, activations): + layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac) layer_set.append(layer) self.conv1d_projections = nn.ModuleList(layer_set) # setup Highway layers if self.highway_features != conv_projections[-1]: - self.pre_highway = nn.Linear(conv_projections[-1], - highway_features, - bias=False) - self.highways = nn.ModuleList([ - Highway(highway_features, highway_features) - for _ in range(num_highways) - ]) + self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False) + self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)]) # bi-directional GPU layer - self.gru = nn.GRU(gru_features, - gru_features, - 1, - batch_first=True, - bidirectional=True) + self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True) def forward(self, inputs): # (B, in_features, T_in) @@ -210,7 +190,7 @@ class EncoderCBHG(nn.Module): r"""CBHG module with Encoder specific arguments""" def __init__(self): - super(EncoderCBHG, self).__init__() + super().__init__() self.cbhg = CBHG( 128, K=16, @@ -218,7 +198,8 @@ class EncoderCBHG(nn.Module): conv_projections=[128, 128], highway_features=128, gru_features=128, - num_highways=4) + num_highways=4, + ) def forward(self, x): return self.cbhg(x) @@ -235,7 +216,7 @@ class Encoder(nn.Module): """ def __init__(self, in_features): - super(Encoder, self).__init__() + super().__init__() self.prenet = Prenet(in_features, out_features=[256, 128]) self.cbhg = EncoderCBHG() @@ -248,7 +229,7 @@ class Encoder(nn.Module): class PostCBHG(nn.Module): def __init__(self, mel_dim): - super(PostCBHG, self).__init__() + super().__init__() self.cbhg = CBHG( mel_dim, K=8, @@ -256,7 +237,8 @@ class PostCBHG(nn.Module): conv_projections=[256, mel_dim], highway_features=128, gru_features=128, - num_highways=4) + num_highways=4, + ) def forward(self, x): return self.cbhg(x) @@ -289,11 +271,25 @@ class Decoder(nn.Module): # Pylint gets confused by PyTorch conventions here # pylint: disable=attribute-defined-outside-init - def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing, - attn_norm, prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet): - super(Decoder, self).__init__() + def __init__( + self, + in_channels, + frame_channels, + r, + memory_size, + attn_type, + attn_windowing, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ): + super().__init__() self.r_init = r self.r = r self.in_channels = in_channels @@ -305,33 +301,30 @@ class Decoder(nn.Module): self.query_dim = 256 # memory -> |Prenet| -> processed_memory prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels - self.prenet = Prenet( - prenet_dim, - prenet_type, - prenet_dropout, - out_features=[256, 128]) + self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) # processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State # attention_rnn generates queries for the attention mechanism self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim) - self.attention = init_attn(attn_type=attn_type, - query_dim=self.query_dim, - embedding_dim=in_channels, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_windowing, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask, - attn_K=attn_K) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_windowing, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input self.project_to_decoder_in = nn.Linear(256 + in_channels, 256) # decoder_RNN_input -> |RNN| -> RNN_state - self.decoder_rnns = nn.ModuleList( - [nn.GRUCell(256, 256) for _ in range(2)]) + self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)]) # RNN_state -> |Linear| -> mel_spec self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init) # learn init values instead of zero init. @@ -364,8 +357,7 @@ class Decoder(nn.Module): # decoder states self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256) self.decoder_rnn_hiddens = [ - torch.zeros(1, device=inputs.device).repeat(B, 256) - for idx in range(len(self.decoder_rnns)) + torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns)) ] self.context_vec = inputs.data.new(B, self.in_channels).zero_() # cache attention inputs @@ -376,8 +368,7 @@ class Decoder(nn.Module): attentions = torch.stack(attentions).transpose(0, 1) stop_tokens = torch.stack(stop_tokens).transpose(0, 1) outputs = torch.stack(outputs).transpose(0, 1).contiguous() - outputs = outputs.view( - outputs.size(0), -1, self.frame_channels) + outputs = outputs.view(outputs.size(0), -1, self.frame_channels) outputs = outputs.transpose(1, 2) return outputs, attentions, stop_tokens @@ -386,18 +377,15 @@ class Decoder(nn.Module): processed_memory = self.prenet(self.memory_input) # Attention RNN self.attention_rnn_hidden = self.attention_rnn( - torch.cat((processed_memory, self.context_vec), -1), - self.attention_rnn_hidden) - self.context_vec = self.attention( - self.attention_rnn_hidden, inputs, self.processed_inputs, mask) + torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden + ) + self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask) # Concat RNN output and attention context vector - decoder_input = self.project_to_decoder_in( - torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) + decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1)) # Pass through the decoder RNNs for idx in range(len(self.decoder_rnns)): - self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx]( - decoder_input, self.decoder_rnn_hiddens[idx]) + self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx]) # Residual connection decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input decoder_output = decoder_input @@ -418,17 +406,17 @@ class Decoder(nn.Module): if self.use_memory_queue: if self.memory_size > self.r: # memory queue size is larger than number of frames per decoder iter - self.memory_input = torch.cat([ - new_memory, self.memory_input[:, :( - self.memory_size - self.r) * self.frame_channels].clone() - ], dim=-1) + self.memory_input = torch.cat( + [new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()], + dim=-1, + ) else: # memory queue size smaller than number of frames per decoder iter - self.memory_input = new_memory[:, :self.memory_size * self.frame_channels] + self.memory_input = new_memory[:, : self.memory_size * self.frame_channels] else: # use only the last frame prediction # assert new_memory.shape[-1] == self.r * self.frame_channels - self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):] + self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :] def forward(self, inputs, memory, mask): """ @@ -487,8 +475,7 @@ class Decoder(nn.Module): attentions += [attention] stop_tokens += [stop_token] t += 1 - if t > inputs.shape[1] / 4 and (stop_token > 0.6 - or attention[:, -1].item() > 0.6): + if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6): break if t > self.max_decoder_steps: print(" | > Decoder stopped with 'max_decoder_steps") @@ -503,11 +490,10 @@ class StopNet(nn.Module): """ def __init__(self, in_features): - super(StopNet, self).__init__() + super().__init__() self.dropout = nn.Dropout(0.1) self.linear = nn.Linear(in_features, 1) - torch.nn.init.xavier_uniform_( - self.linear.weight, gain=torch.nn.init.calculate_gain('linear')) + torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear")) def forward(self, inputs): outputs = self.dropout(inputs) diff --git a/TTS/tts/layers/tacotron/tacotron2.py b/TTS/tts/layers/tacotron/tacotron2.py index 8e6dbc15..755598c6 100644 --- a/TTS/tts/layers/tacotron/tacotron2.py +++ b/TTS/tts/layers/tacotron/tacotron2.py @@ -5,8 +5,8 @@ from .common_layers import Prenet, Linear from .attentions import init_attn # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg class ConvBNBlock(nn.Module): r"""Convolutions with Batch Normalization and non-linear activation. @@ -20,19 +20,17 @@ class ConvBNBlock(nn.Module): - input: (B, C_in, T) - output: (B, C_out, T) """ + def __init__(self, in_channels, out_channels, kernel_size, activation=None): - super(ConvBNBlock, self).__init__() + super().__init__() assert (kernel_size - 1) % 2 == 0 padding = (kernel_size - 1) // 2 - self.convolution1d = nn.Conv1d(in_channels, - out_channels, - kernel_size, - padding=padding) + self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding) self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5) self.dropout = nn.Dropout(p=0.5) - if activation == 'relu': + if activation == "relu": self.activation = nn.ReLU() - elif activation == 'tanh': + elif activation == "tanh": self.activation = nn.Tanh() else: self.activation = nn.Identity() @@ -55,16 +53,14 @@ class Postnet(nn.Module): - input: (B, C_in, T) - output: (B, C_in, T) """ + def __init__(self, in_out_channels, num_convs=5): - super(Postnet, self).__init__() + super().__init__() self.convolutions = nn.ModuleList() - self.convolutions.append( - ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh')) + self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh")) for _ in range(1, num_convs - 1): - self.convolutions.append( - ConvBNBlock(512, 512, kernel_size=5, activation='tanh')) - self.convolutions.append( - ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) + self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh")) + self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None)) def forward(self, x): o = x @@ -83,18 +79,15 @@ class Encoder(nn.Module): - input: (B, C_in, T) - output: (B, C_in, T) """ + def __init__(self, in_out_channels=512): - super(Encoder, self).__init__() + super().__init__() self.convolutions = nn.ModuleList() for _ in range(3): - self.convolutions.append( - ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu')) - self.lstm = nn.LSTM(in_out_channels, - int(in_out_channels / 2), - num_layers=1, - batch_first=True, - bias=True, - bidirectional=True) + self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu")) + self.lstm = nn.LSTM( + in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True + ) self.rnn_state = None def forward(self, x, input_lengths): @@ -102,9 +95,7 @@ class Encoder(nn.Module): for layer in self.convolutions: o = layer(o) o = o.transpose(1, 2) - o = nn.utils.rnn.pack_padded_sequence(o, - input_lengths.cpu(), - batch_first=True) + o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True) self.lstm.flatten_parameters() o, _ = self.lstm(o) o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True) @@ -143,12 +134,27 @@ class Decoder(nn.Module): attn_K (int): number of attention heads for GravesAttention. separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow. """ + # Pylint gets confused by PyTorch conventions here - #pylint: disable=attribute-defined-outside-init - def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm, - prenet_type, prenet_dropout, forward_attn, trans_agent, - forward_attn_mask, location_attn, attn_K, separate_stopnet): - super(Decoder, self).__init__() + # pylint: disable=attribute-defined-outside-init + def __init__( + self, + in_channels, + frame_channels, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ): + super().__init__() self.frame_channels = frame_channels self.r_init = r self.r = r @@ -167,43 +173,36 @@ class Decoder(nn.Module): # memory -> |Prenet| -> processed_memory prenet_dim = self.frame_channels - self.prenet = Prenet(prenet_dim, - prenet_type, - prenet_dropout, - out_features=[self.prenet_dim, self.prenet_dim], - bias=False) + self.prenet = Prenet( + prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False + ) - self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, - self.query_dim, - bias=True) + self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True) - self.attention = init_attn(attn_type=attn_type, - query_dim=self.query_dim, - embedding_dim=in_channels, - attention_dim=128, - location_attention=location_attn, - attention_location_n_filters=32, - attention_location_kernel_size=31, - windowing=attn_win, - norm=attn_norm, - forward_attn=forward_attn, - trans_agent=trans_agent, - forward_attn_mask=forward_attn_mask, - attn_K=attn_K) + self.attention = init_attn( + attn_type=attn_type, + query_dim=self.query_dim, + embedding_dim=in_channels, + attention_dim=128, + location_attention=location_attn, + attention_location_n_filters=32, + attention_location_kernel_size=31, + windowing=attn_win, + norm=attn_norm, + forward_attn=forward_attn, + trans_agent=trans_agent, + forward_attn_mask=forward_attn_mask, + attn_K=attn_K, + ) - self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, - self.decoder_rnn_dim, - bias=True) + self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True) - self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, - self.frame_channels * self.r_init) + self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init) self.stopnet = nn.Sequential( nn.Dropout(0.1), - Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, - 1, - bias=True, - init_gain='sigmoid')) + Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"), + ) self.memory_truncated = None def set_r(self, new_r): @@ -211,24 +210,18 @@ class Decoder(nn.Module): def get_go_frame(self, inputs): B = inputs.size(0) - memory = torch.zeros(1, device=inputs.device).repeat( - B, self.frame_channels * self.r) + memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r) return memory def _init_states(self, inputs, mask, keep_states=False): B = inputs.size(0) # T = inputs.size(1) if not keep_states: - self.query = torch.zeros(1, device=inputs.device).repeat( - B, self.query_dim) - self.attention_rnn_cell_state = torch.zeros( - 1, device=inputs.device).repeat(B, self.query_dim) - self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat( - B, self.decoder_rnn_dim) - self.decoder_cell = torch.zeros(1, device=inputs.device).repeat( - B, self.decoder_rnn_dim) - self.context = torch.zeros(1, device=inputs.device).repeat( - B, self.encoder_embedding_dim) + self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim) + self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim) + self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim) self.inputs = inputs self.processed_inputs = self.attention.preprocess_inputs(inputs) self.mask = mask @@ -254,38 +247,36 @@ class Decoder(nn.Module): def _update_memory(self, memory): if len(memory.shape) == 2: - return memory[:, self.frame_channels * (self.r - 1):] - return memory[:, :, self.frame_channels * (self.r - 1):] + return memory[:, self.frame_channels * (self.r - 1) :] + return memory[:, :, self.frame_channels * (self.r - 1) :] def decode(self, memory): - ''' - shapes: - - memory: B x r * self.frame_channels - ''' + """ + shapes: + - memory: B x r * self.frame_channels + """ # self.context: B x D_en # query_input: B x D_en + (r * self.frame_channels) query_input = torch.cat((memory, self.context), -1) # self.query and self.attention_rnn_cell_state : B x D_attn_rnn self.query, self.attention_rnn_cell_state = self.attention_rnn( - query_input, (self.query, self.attention_rnn_cell_state)) - self.query = F.dropout(self.query, self.p_attention_dropout, - self.training) + query_input, (self.query, self.attention_rnn_cell_state) + ) + self.query = F.dropout(self.query, self.p_attention_dropout, self.training) self.attention_rnn_cell_state = F.dropout( - self.attention_rnn_cell_state, self.p_attention_dropout, - self.training) + self.attention_rnn_cell_state, self.p_attention_dropout, self.training + ) # B x D_en - self.context = self.attention(self.query, self.inputs, - self.processed_inputs, self.mask) + self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask) # B x (D_en + D_attn_rnn) decoder_rnn_input = torch.cat((self.query, self.context), -1) # self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn self.decoder_hidden, self.decoder_cell = self.decoder_rnn( - decoder_rnn_input, (self.decoder_hidden, self.decoder_cell)) - self.decoder_hidden = F.dropout(self.decoder_hidden, - self.p_decoder_dropout, self.training) + decoder_rnn_input, (self.decoder_hidden, self.decoder_cell) + ) + self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training) # B x (D_decoder_rnn + D_en) - decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), - dim=1) + decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1) # B x (self.r * self.frame_channels) decoder_output = self.linear_projection(decoder_hidden_context) # B x (D_decoder_rnn + (self.r * self.frame_channels)) @@ -295,7 +286,7 @@ class Decoder(nn.Module): else: stop_token = self.stopnet(stopnet_input) # select outputs for the reduction rate self.r - decoder_output = decoder_output[:, :self.r * self.frame_channels] + decoder_output = decoder_output[:, : self.r * self.frame_channels] return decoder_output, self.attention.attention_weights, stop_token def forward(self, inputs, memories, mask): @@ -329,8 +320,7 @@ class Decoder(nn.Module): stop_tokens += [stop_token.squeeze(1)] alignments += [attention_weights] - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens def inference(self, inputs): @@ -369,8 +359,7 @@ class Decoder(nn.Module): memory = self._update_memory(decoder_output) t += 1 - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens @@ -404,8 +393,7 @@ class Decoder(nn.Module): self.memory_truncated = decoder_output t += 1 - outputs, stop_tokens, alignments = self._parse_outputs( - outputs, stop_tokens, alignments) + outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments) return outputs, alignments, stop_tokens diff --git a/TTS/tts/models/align_tts.py b/TTS/tts/models/align_tts.py index 16cb013a..903b99c8 100644 --- a/TTS/tts/models/align_tts.py +++ b/TTS/tts/models/align_tts.py @@ -62,39 +62,27 @@ class AlignTTS(nn.Module): # pylint: disable=dangerous-default-value def __init__( - self, - num_chars, - out_channels, - hidden_channels=256, - hidden_channels_dp=256, - encoder_type='fftransformer', - encoder_params={ - 'hidden_channels_ffn': 1024, - 'num_heads': 2, - 'num_layers': 6, - 'dropout_p': 0.1 - }, - decoder_type='fftransformer', - decoder_params={ - 'hidden_channels_ffn': 1024, - 'num_heads': 2, - 'num_layers': 6, - 'dropout_p': 0.1 - }, - length_scale=1, - num_speakers=0, - external_c=False, - c_in_channels=0): + self, + num_chars, + out_channels, + hidden_channels=256, + hidden_channels_dp=256, + encoder_type="fftransformer", + encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, + decoder_type="fftransformer", + decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1}, + length_scale=1, + num_speakers=0, + external_c=False, + c_in_channels=0, + ): super().__init__() - self.length_scale = float(length_scale) if isinstance( - length_scale, int) else length_scale + self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale self.emb = nn.Embedding(num_chars, hidden_channels) self.pos_encoder = PositionalEncoding(hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, - encoder_params, c_in_channels) - self.decoder = Decoder(out_channels, hidden_channels, decoder_type, - decoder_params) + self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) + self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) self.duration_predictor = DurationPredictor(hidden_channels_dp) self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1) @@ -111,13 +99,13 @@ class AlignTTS(nn.Module): @staticmethod def compute_log_probs(mu, log_sigma, y): # pylint: disable=protected-access, c-extension-no-member - y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] - mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] - log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D] + mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] + log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D] expanded_y, expanded_mu = torch.broadcast_tensors(y, mu) - exponential = -0.5 * torch.mean(torch._C._nn.mse_loss( - expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), - dim=-1) # B, L, T + exponential = -0.5 * torch.mean( + torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1 + ) # B, L, T logp = exponential - 0.5 * log_sigma.mean(dim=-1) return logp @@ -151,9 +139,7 @@ class AlignTTS(nn.Module): [1, 0, 0, 0, 0, 0, 0]] """ attn = self.convert_dr_to_align(dr, x_mask, y_mask) - o_en_ex = torch.matmul( - attn.squeeze(1).transpose(1, 2), en.transpose(1, - 2)).transpose(1, 2) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): @@ -170,12 +156,12 @@ class AlignTTS(nn.Module): def _sum_speaker_embedding(self, x, g): # project g to decoder dim. - if hasattr(self, 'proj_g'): + if hasattr(self, "proj_g"): g = self.proj_g(g) return x + g def _forward_encoder(self, x, x_lengths, g=None): - if hasattr(self, 'emb_g'): + if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] if g is not None: @@ -187,8 +173,7 @@ class AlignTTS(nn.Module): x_emb = torch.transpose(x_emb, 1, -1) # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) # encoder pass o_en = self.encoder(x_emb, x_mask) @@ -201,12 +186,11 @@ class AlignTTS(nn.Module): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding - if hasattr(self, 'pos_encoder'): + if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) # speaker embedding if g is not None: @@ -218,10 +202,8 @@ class AlignTTS(nn.Module): def _forward_mdn(self, o_en, y, y_lengths, x_mask): # MAS potentials and alignment mu, log_sigma = self.mdn_block(o_en) - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en.dtype) - dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, - y_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) + dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask) return dr_mas, mu, log_sigma, logp def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument @@ -237,56 +219,31 @@ class AlignTTS(nn.Module): if phase == 0: # train encoder and MDN o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask) elif phase == 1: # train decoder o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en.detach(), - o_en_dp.detach(), - dr_mas.detach(), - x_mask, - y_lengths, - g=g) + o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g) elif phase == 2: # train the whole except duration predictor o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) elif phase == 3: # train duration predictor o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_dr_log = self.duration_predictor(x, x_mask) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) o_dr_log = o_dr_log.squeeze(1) else: o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) - dr_mas, mu, log_sigma, logp = self._forward_mdn( - o_en, y, y_lengths, x_mask) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - dr_mas, - x_mask, - y_lengths, - g=g) + dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask) + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g) o_dr_log = o_dr_log.squeeze(1) dr_mas_log = torch.log(dr_mas + 1).squeeze(1) return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp @@ -307,17 +264,14 @@ class AlignTTS(nn.Module): # duration predictor pass o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) y_lengths = o_dr.sum(1) - o_de, attn = self._forward_decoder(o_en, - o_en_dp, - o_dr, - x_mask, - y_lengths, - g=g) + o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) return o_de, attn - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/tts/models/glow_tts.py b/TTS/tts/models/glow_tts.py index 2e01f87c..77222cba 100644 --- a/TTS/tts/models/glow_tts.py +++ b/TTS/tts/models/glow_tts.py @@ -34,28 +34,31 @@ class GlowTTS(nn.Module): encoder_params (dict): encoder module parameters. external_speaker_embedding_dim (int): channels of external speaker embedding vectors. """ - def __init__(self, - num_chars, - hidden_channels_enc, - hidden_channels_dec, - use_encoder_prenet, - hidden_channels_dp, - out_channels, - num_flow_blocks_dec=12, - kernel_size_dec=5, - dilation_rate=5, - num_block_layers=4, - dropout_p_dp=0.1, - dropout_p_dec=0.05, - num_speakers=0, - c_in_channels=0, - num_splits=4, - num_squeeze=1, - sigmoid_scale=False, - mean_only=False, - encoder_type="transformer", - encoder_params=None, - external_speaker_embedding_dim=None): + + def __init__( + self, + num_chars, + hidden_channels_enc, + hidden_channels_dec, + use_encoder_prenet, + hidden_channels_dp, + out_channels, + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=5, + num_block_layers=4, + dropout_p_dp=0.1, + dropout_p_dec=0.05, + num_speakers=0, + c_in_channels=0, + num_splits=4, + num_squeeze=1, + sigmoid_scale=False, + mean_only=False, + encoder_type="transformer", + encoder_params=None, + external_speaker_embedding_dim=None, + ): super().__init__() self.num_chars = num_chars @@ -78,7 +81,7 @@ class GlowTTS(nn.Module): # model constants. self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference. - self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech. + self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech. self.external_speaker_embedding_dim = external_speaker_embedding_dim # if is a multispeaker and c_in_channels is 0, set to 256 @@ -88,28 +91,32 @@ class GlowTTS(nn.Module): elif self.external_speaker_embedding_dim: self.c_in_channels = self.external_speaker_embedding_dim - self.encoder = Encoder(num_chars, - out_channels=out_channels, - hidden_channels=hidden_channels_enc, - hidden_channels_dp=hidden_channels_dp, - encoder_type=encoder_type, - encoder_params=encoder_params, - mean_only=mean_only, - use_prenet=use_encoder_prenet, - dropout_p_dp=dropout_p_dp, - c_in_channels=self.c_in_channels) + self.encoder = Encoder( + num_chars, + out_channels=out_channels, + hidden_channels=hidden_channels_enc, + hidden_channels_dp=hidden_channels_dp, + encoder_type=encoder_type, + encoder_params=encoder_params, + mean_only=mean_only, + use_prenet=use_encoder_prenet, + dropout_p_dp=dropout_p_dp, + c_in_channels=self.c_in_channels, + ) - self.decoder = Decoder(out_channels, - hidden_channels_dec, - kernel_size_dec, - dilation_rate, - num_flow_blocks_dec, - num_block_layers, - dropout_p=dropout_p_dec, - num_splits=num_splits, - num_squeeze=num_squeeze, - sigmoid_scale=sigmoid_scale, - c_in_channels=self.c_in_channels) + self.decoder = Decoder( + out_channels, + hidden_channels_dec, + kernel_size_dec, + dilation_rate, + num_flow_blocks_dec, + num_block_layers, + dropout_p=dropout_p_dec, + num_splits=num_splits, + num_squeeze=num_squeeze, + sigmoid_scale=sigmoid_scale, + c_in_channels=self.c_in_channels, + ) if num_speakers > 1 and not external_speaker_embedding_dim: # speaker embedding layer @@ -119,12 +126,12 @@ class GlowTTS(nn.Module): @staticmethod def compute_outputs(attn, o_mean, o_log_scale, x_mask): # compute final values with the computed alignment - y_mean = torch.matmul( - attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( - 1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] - y_log_scale = torch.matmul( - attn.squeeze(1).transpose(1, 2), o_log_scale.transpose( - 1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] + y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose( + 1, 2 + ) # [b, t', t], [b, t, d] -> [b, d, t'] # compute total duration with adjustment o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask return y_mean, y_log_scale, o_attn_dur @@ -144,37 +151,27 @@ class GlowTTS(nn.Module): if self.external_speaker_embedding_dim: g = F.normalize(g).unsqueeze(-1) else: - g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1] + g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1] # embedding pass - o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, - x_lengths, - g=g) + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # drop redisual frames wrt num_squeeze and set y_lengths. - y, y_lengths, y_max_length, attn = self.preprocess( - y, y_lengths, y_max_length, None) + y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None) # create masks - y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), - 1).to(x_mask.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # decoder pass z, logdet = self.decoder(y, y_mask, g=g, reverse=False) # find the alignment path with torch.no_grad(): o_scale = torch.exp(-2 * o_log_scale) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, - [1]).unsqueeze(-1) # [b, t, 1] - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * - (z**2)) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), - z) # [b, t, d] x [b, d, t'] = [b, t, t'] - logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, - [1]).unsqueeze(-1) # [b, t, 1] + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1] + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t'] + logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1] logp = logp1 + logp2 + logp3 + logp4 # [b, t, t'] - attn = maximum_path(logp, - attn_mask.squeeze(1)).unsqueeze(1).detach() - y_mean, y_log_scale, o_attn_dur = self.compute_outputs( - attn, o_mean, o_log_scale, x_mask) + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) attn = attn.squeeze(1).permute(0, 2, 1) return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur @@ -187,26 +184,20 @@ class GlowTTS(nn.Module): g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h] # embedding pass - o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, - x_lengths, - g=g) + o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g) # compute output durations w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale w_ceil = torch.ceil(w) y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() y_max_length = None # compute masks - y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), - 1).to(x_mask.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype) attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) # compute attention mask - attn = generate_path(w_ceil.squeeze(1), - attn_mask.squeeze(1)).unsqueeze(1) - y_mean, y_log_scale, o_attn_dur = self.compute_outputs( - attn, o_mean, o_log_scale, x_mask) + attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) + y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask) - z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * - self.noise_scale) * y_mask + z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask # decoder pass y, logdet = self.decoder(z, y_mask, g=g, reverse=True) attn = attn.squeeze(1).permute(0, 2, 1) @@ -224,9 +215,11 @@ class GlowTTS(nn.Module): def store_inverse(self): self.decoder.store_inverse() - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() self.store_inverse() diff --git a/TTS/tts/models/speedy_speech.py b/TTS/tts/models/speedy_speech.py index 00cba5c7..0bad9ede 100644 --- a/TTS/tts/models/speedy_speech.py +++ b/TTS/tts/models/speedy_speech.py @@ -34,42 +34,37 @@ class SpeedySpeech(nn.Module): external_c (bool, optional): enable external speaker embeddings. Defaults to False. c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0. """ + # pylint: disable=dangerous-default-value def __init__( - self, - num_chars, - out_channels, - hidden_channels, - positional_encoding=True, - length_scale=1, - encoder_type='residual_conv_bn', - encoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 13 - }, - decoder_type='residual_conv_bn', - decoder_params={ - "kernel_size": 4, - "dilations": 4 * [1, 2, 4, 8] + [1], - "num_conv_blocks": 2, - "num_res_blocks": 17 - }, - num_speakers=0, - external_c=False, - c_in_channels=0): + self, + num_chars, + out_channels, + hidden_channels, + positional_encoding=True, + length_scale=1, + encoder_type="residual_conv_bn", + encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13}, + decoder_type="residual_conv_bn", + decoder_params={ + "kernel_size": 4, + "dilations": 4 * [1, 2, 4, 8] + [1], + "num_conv_blocks": 2, + "num_res_blocks": 17, + }, + num_speakers=0, + external_c=False, + c_in_channels=0, + ): super().__init__() self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale self.emb = nn.Embedding(num_chars, hidden_channels) - self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, - encoder_params, c_in_channels) + self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels) if positional_encoding: self.pos_encoder = PositionalEncoding(hidden_channels) - self.decoder = Decoder(out_channels, hidden_channels, - decoder_type, decoder_params) + self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params) self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels) if num_speakers > 1 and not external_c: @@ -97,9 +92,7 @@ class SpeedySpeech(nn.Module): """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul( - attn.squeeze(1).transpose(1, 2), en.transpose(1, - 2)).transpose(1, 2) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): @@ -116,12 +109,12 @@ class SpeedySpeech(nn.Module): def _sum_speaker_embedding(self, x, g): # project g to decoder dim. - if hasattr(self, 'proj_g'): + if hasattr(self, "proj_g"): g = self.proj_g(g) return x + g def _forward_encoder(self, x, x_lengths, g=None): - if hasattr(self, 'emb_g'): + if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] if g is not None: @@ -133,8 +126,7 @@ class SpeedySpeech(nn.Module): x_emb = torch.transpose(x_emb, 1, -1) # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), - 1).to(x.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) # encoder pass o_en = self.encoder(x_emb, x_mask) @@ -147,12 +139,11 @@ class SpeedySpeech(nn.Module): return o_en, o_en_dp, x_mask, g def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), - 1).to(o_en_dp.dtype) + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding - if hasattr(self, 'pos_encoder'): + if hasattr(self, "pos_encoder"): o_en_ex = self.pos_encoder(o_en_ex, y_mask) # speaker embedding if g is not None: @@ -187,7 +178,7 @@ class SpeedySpeech(nn.Module): if x.shape[1] < 13: inference_padding += 13 - x.shape[1] # pad input to prevent dropping the last word - x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode='constant', value=0) + x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) # duration predictor pass o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) @@ -196,9 +187,11 @@ class SpeedySpeech(nn.Module): o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) return o_de, attn - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 541c4159..0254149d 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -47,45 +47,67 @@ class Tacotron(TacotronAbstract): memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size``` output frames to the prenet. """ - def __init__(self, - num_chars, - num_speakers, - r=5, - postnet_output_dim=1025, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="sigmoid", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=256, - decoder_in_features=256, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=256, - gst_num_heads=4, - gst_style_tokens=10, - memory_size=5, - gst_use_speaker_embedding=False): - super(Tacotron, - self).__init__(num_chars, num_speakers, r, postnet_output_dim, - decoder_output_dim, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, - bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, - gst_num_heads, gst_style_tokens, gst_use_speaker_embedding) + + def __init__( + self, + num_chars, + num_speakers, + r=5, + postnet_output_dim=1025, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="sigmoid", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=256, + decoder_in_features=256, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=256, + gst_num_heads=4, + gst_style_tokens=10, + memory_size=5, + gst_use_speaker_embedding=False, + ): + super().__init__( + num_chars, + num_speakers, + r, + postnet_output_dim, + decoder_output_dim, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + bidirectional_decoder, + double_decoder_consistency, + ddc_r, + encoder_in_features, + decoder_in_features, + speaker_embedding_dim, + gst, + gst_embedding_dim, + gst_num_heads, + gst_style_tokens, + gst_use_speaker_embedding, + ) # speaker embedding layers if self.num_speakers > 1: @@ -96,7 +118,7 @@ class Tacotron(TacotronAbstract): # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: - self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 256, padding_idx=0) @@ -104,32 +126,59 @@ class Tacotron(TacotronAbstract): # base model layers self.encoder = Encoder(self.encoder_in_features) - self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r, - memory_size, attn_type, attn_win, attn_norm, - prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet) + self.decoder = Decoder( + self.decoder_in_features, + decoder_output_dim, + r, + memory_size, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) self.postnet = PostCBHG(decoder_output_dim) - self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, - postnet_output_dim) + self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim) # global style token layers if self.gst: - self.gst_layer = GST(num_mel=80, - num_heads=gst_num_heads, - num_style_tokens=gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None) + self.gst_layer = GST( + num_mel=80, + num_heads=gst_num_heads, + num_style_tokens=gst_style_tokens, + gst_embedding_dim=self.gst_embedding_dim, + speaker_embedding_dim=speaker_embedding_dim + if self.embeddings_per_sample and self.gst_use_speaker_embedding + else None, + ) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - self.decoder_in_features, decoder_output_dim, ddc_r, memory_size, - attn_type, attn_win, attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, location_attn, - attn_K, separate_stopnet) + self.decoder_in_features, + decoder_output_dim, + ddc_r, + memory_size, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None): """ @@ -151,9 +200,9 @@ class Tacotron(TacotronAbstract): # global style token if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - mel_specs, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None + ) # speaker embedding if self.num_speakers > 1: if not self.embeddings_per_sample: @@ -166,8 +215,7 @@ class Tacotron(TacotronAbstract): # decoder_outputs: B x decoder_in_features x T_out # alignments: B x T_in x encoder_in_features # stop_tokens: B x T_in - decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) # sequence masking if output_mask is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -182,10 +230,26 @@ class Tacotron(TacotronAbstract): decoder_outputs = decoder_outputs.transpose(1, 2).contiguous() if self.bidirectional_decoder: decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) if self.double_decoder_consistency: - decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() @@ -194,9 +258,9 @@ class Tacotron(TacotronAbstract): encoder_outputs = self.encoder(inputs) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim @@ -205,8 +269,7 @@ class Tacotron(TacotronAbstract): # B x 1 x speaker_embed_dim speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1) encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - decoder_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = self.last_linear(postnet_outputs) decoder_outputs = decoder_outputs.transpose(1, 2) diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index 0e751c32..25386fca 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -44,44 +44,66 @@ class Tacotron2(TacotronAbstract): gst_style_tokens (int, optional): number of GST tokens. Defaults to 10. gst_use_speaker_embedding (bool, optional): enable/disable inputing speaker embedding to GST. Defaults to False. """ - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False): - super(Tacotron2, - self).__init__(num_chars, num_speakers, r, postnet_output_dim, - decoder_output_dim, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet, - bidirectional_decoder, double_decoder_consistency, - ddc_r, encoder_in_features, decoder_in_features, - speaker_embedding_dim, gst, gst_embedding_dim, - gst_num_heads, gst_style_tokens, gst_use_speaker_embedding) + + def __init__( + self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10, + gst_use_speaker_embedding=False, + ): + super().__init__( + num_chars, + num_speakers, + r, + postnet_output_dim, + decoder_output_dim, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + bidirectional_decoder, + double_decoder_consistency, + ddc_r, + encoder_in_features, + decoder_in_features, + speaker_embedding_dim, + gst, + gst_embedding_dim, + gst_num_heads, + gst_style_tokens, + gst_use_speaker_embedding, + ) # speaker embedding layer if self.num_speakers > 1: @@ -92,36 +114,63 @@ class Tacotron2(TacotronAbstract): # speaker and gst embeddings is concat in decoder input if self.num_speakers > 1: - self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim + self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim # embedding layer self.embedding = nn.Embedding(num_chars, 512, padding_idx=0) # base model layers self.encoder = Encoder(self.encoder_in_features) - self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win, - attn_norm, prenet_type, prenet_dropout, - forward_attn, trans_agent, forward_attn_mask, - location_attn, attn_K, separate_stopnet) + self.decoder = Decoder( + self.decoder_in_features, + self.decoder_output_dim, + r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) self.postnet = Postnet(self.postnet_output_dim) # global style token layers if self.gst: - self.gst_layer = GST(num_mel=80, - num_heads=self.gst_num_heads, - num_style_tokens=self.gst_style_tokens, - gst_embedding_dim=self.gst_embedding_dim, - speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None) + self.gst_layer = GST( + num_mel=80, + num_heads=self.gst_num_heads, + num_style_tokens=self.gst_style_tokens, + gst_embedding_dim=self.gst_embedding_dim, + speaker_embedding_dim=speaker_embedding_dim + if self.embeddings_per_sample and self.gst_use_speaker_embedding + else None, + ) # backward pass decoder if self.bidirectional_decoder: self._init_backward_decoder() # setup DDC if self.double_decoder_consistency: self.coarse_decoder = Decoder( - self.decoder_in_features, self.decoder_output_dim, ddc_r, attn_type, - attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, - trans_agent, forward_attn_mask, location_attn, attn_K, - separate_stopnet) + self.decoder_in_features, + self.decoder_output_dim, + ddc_r, + attn_type, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + attn_K, + separate_stopnet, + ) @staticmethod def shape_outputs(mel_outputs, mel_outputs_postnet, alignments): @@ -148,9 +197,9 @@ class Tacotron2(TacotronAbstract): encoder_outputs = self.encoder(embedded_inputs, text_lengths) if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - mel_specs, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: # B x 1 x speaker_embed_dim @@ -163,8 +212,7 @@ class Tacotron2(TacotronAbstract): encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs) # B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r - decoder_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, input_mask) + decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask) # sequence masking if mel_lengths is not None: decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs) @@ -175,14 +223,29 @@ class Tacotron2(TacotronAbstract): if output_mask is not None: postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs) # B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in - decoder_outputs, postnet_outputs, alignments = self.shape_outputs( - decoder_outputs, postnet_outputs, alignments) + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) if self.bidirectional_decoder: decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) if self.double_decoder_consistency: - decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask) - return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward + decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass( + mel_specs, encoder_outputs, alignments, input_mask + ) + return ( + decoder_outputs, + postnet_outputs, + alignments, + stop_tokens, + decoder_outputs_backward, + alignments_backward, + ) return decoder_outputs, postnet_outputs, alignments, stop_tokens @torch.no_grad() @@ -192,20 +255,18 @@ class Tacotron2(TacotronAbstract): if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - decoder_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs) postnet_outputs = self.postnet(decoder_outputs) postnet_outputs = decoder_outputs + postnet_outputs - decoder_outputs, postnet_outputs, alignments = self.shape_outputs( - decoder_outputs, postnet_outputs, alignments) + decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments) return decoder_outputs, postnet_outputs, alignments, stop_tokens def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None): @@ -217,19 +278,17 @@ class Tacotron2(TacotronAbstract): if self.gst: # B x gst_dim - encoder_outputs = self.compute_gst(encoder_outputs, - style_mel, - speaker_embeddings if self.gst_use_speaker_embedding else None) + encoder_outputs = self.compute_gst( + encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None + ) if self.num_speakers > 1: if not self.embeddings_per_sample: speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None] encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings) - mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated( - encoder_outputs) + mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs) mel_outputs_postnet = self.postnet(mel_outputs) mel_outputs_postnet = mel_outputs + mel_outputs_postnet - mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs( - mel_outputs, mel_outputs_postnet, alignments) + mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments) return mel_outputs, mel_outputs_postnet, alignments, stop_tokens diff --git a/TTS/tts/models/tacotron_abstract.py b/TTS/tts/models/tacotron_abstract.py index 22e86ee4..820dd8b8 100644 --- a/TTS/tts/models/tacotron_abstract.py +++ b/TTS/tts/models/tacotron_abstract.py @@ -8,34 +8,36 @@ from TTS.tts.utils.generic_utils import sequence_mask class TacotronAbstract(ABC, nn.Module): - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - attn_K=5, - separate_stopnet=True, - bidirectional_decoder=False, - double_decoder_consistency=False, - ddc_r=None, - encoder_in_features=512, - decoder_in_features=512, - speaker_embedding_dim=None, - gst=False, - gst_embedding_dim=512, - gst_num_heads=4, - gst_style_tokens=10, - gst_use_speaker_embedding=False): + def __init__( + self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + attn_K=5, + separate_stopnet=True, + bidirectional_decoder=False, + double_decoder_consistency=False, + ddc_r=None, + encoder_in_features=512, + decoder_in_features=512, + speaker_embedding_dim=None, + gst=False, + gst_embedding_dim=512, + gst_num_heads=4, + gst_style_tokens=10, + gst_use_speaker_embedding=False, + ): """ Abstract Tacotron class """ super().__init__() self.num_chars = num_chars @@ -82,7 +84,7 @@ class TacotronAbstract(ABC, nn.Module): # global style token if self.gst: - self.decoder_in_features += gst_embedding_dim # add gst embedding dim + self.decoder_in_features += gst_embedding_dim # add gst embedding dim self.gst_layer = None # model states @@ -121,10 +123,12 @@ class TacotronAbstract(ABC, nn.Module): def inference(self): pass - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) - self.decoder.set_r(state['r']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) + self.decoder.set_r(state["r"]) if eval: self.eval() assert not self.training @@ -149,25 +153,24 @@ class TacotronAbstract(ABC, nn.Module): def _backward_pass(self, mel_specs, encoder_outputs, mask): """ Run backwards decoder """ decoder_outputs_b, alignments_b, _ = self.decoder_backward( - encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask) + encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask + ) decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous() return decoder_outputs_b, alignments_b - def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, - input_mask): + def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask): """ Double Decoder Consistency """ T = mel_specs.shape[1] if T % self.coarse_decoder.r > 0: padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r) - mel_specs = torch.nn.functional.pad(mel_specs, - (0, 0, 0, padding_size, 0, 0)) + mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0)) decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder( - encoder_outputs.detach(), mel_specs, input_mask) + encoder_outputs.detach(), mel_specs, input_mask + ) # scale_factor = self.decoder.r_init / self.decoder.r alignments_backward = torch.nn.functional.interpolate( - alignments_backward.transpose(1, 2), - size=alignments.shape[1], - mode='nearest').transpose(1, 2) + alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest" + ).transpose(1, 2) decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2) decoder_outputs_backward = decoder_outputs_backward[:, :T, :] return decoder_outputs_backward, alignments_backward @@ -179,20 +182,17 @@ class TacotronAbstract(ABC, nn.Module): def compute_speaker_embedding(self, speaker_ids): """ Compute speaker embedding vectors """ if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError( - " [!] Model has speaker embedding layer but speaker_id is not provided" - ) + raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") if hasattr(self, "speaker_embedding") and speaker_ids is not None: self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1) if hasattr(self, "speaker_project_mel") and speaker_ids is not None: - self.speaker_embeddings_projected = self.speaker_project_mel( - self.speaker_embeddings).squeeze(1) + self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1) def compute_gst(self, inputs, style_input, speaker_embedding=None): """ Compute global style token """ device = inputs.device if isinstance(style_input, dict): - query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device) + query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device) if speaker_embedding is not None: query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1) @@ -205,20 +205,18 @@ class TacotronAbstract(ABC, nn.Module): elif style_input is None: gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device) else: - gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable + gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable inputs = self._concat_speaker_embedding(inputs, gst_outputs) return inputs @staticmethod def _add_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1) outputs = outputs + speaker_embeddings_ return outputs @staticmethod def _concat_speaker_embedding(outputs, speaker_embeddings): - speaker_embeddings_ = speaker_embeddings.expand( - outputs.size(0), outputs.size(1), -1) + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1) outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) return outputs diff --git a/TTS/tts/tf/layers/tacotron/common_layers.py b/TTS/tts/tf/layers/tacotron/common_layers.py index ad18b9fc..886f0e61 100644 --- a/TTS/tts/tf/layers/tacotron/common_layers.py +++ b/TTS/tts/tf/layers/tacotron/common_layers.py @@ -1,16 +1,18 @@ import tensorflow as tf from tensorflow import keras from tensorflow.python.ops import math_ops + # from tensorflow_addons.seq2seq import BahdanauAttention # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg + class Linear(keras.layers.Layer): def __init__(self, units, use_bias, **kwargs): - super(Linear, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') + super().__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") self.activation = keras.layers.ReLU() def call(self, x): @@ -23,9 +25,11 @@ class Linear(keras.layers.Layer): class LinearBN(keras.layers.Layer): def __init__(self, units, use_bias, **kwargs): - super(LinearBN, self).__init__(**kwargs) - self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer') - self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization') + super().__init__(**kwargs) + self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer") + self.batch_normalization = keras.layers.BatchNormalization( + axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization" + ) self.activation = keras.layers.ReLU() def call(self, x, training=None): @@ -39,22 +43,21 @@ class LinearBN(keras.layers.Layer): class Prenet(keras.layers.Layer): - def __init__(self, - prenet_type, - prenet_dropout, - units, - bias, - **kwargs): - super(Prenet, self).__init__(**kwargs) + def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs): + super().__init__(**kwargs) self.prenet_type = prenet_type self.prenet_dropout = prenet_dropout self.linear_layers = [] if prenet_type == "bn": - self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + self.linear_layers += [ + LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) + ] elif prenet_type == "original": - self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)] + self.linear_layers += [ + Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units) + ] else: - raise RuntimeError(' [!] Unknown prenet type.') + raise RuntimeError(" [!] Unknown prenet type.") if prenet_dropout: self.dropout = keras.layers.Dropout(rate=0.5) @@ -80,11 +83,22 @@ def _sigmoid_norm(score): class Attention(keras.layers.Layer): """TODO: implement forward_attention TODO: location sensitive attention - TODO: implement attention windowing """ - def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters, - loc_attn_kernel_size, use_windowing, norm, use_forward_attn, - use_trans_agent, use_forward_attn_mask, **kwargs): - super(Attention, self).__init__(**kwargs) + TODO: implement attention windowing""" + + def __init__( + self, + attn_dim, + use_loc_attn, + loc_attn_n_filters, + loc_attn_kernel_size, + use_windowing, + norm, + use_forward_attn, + use_trans_agent, + use_forward_attn_mask, + **kwargs, + ): + super().__init__(**kwargs) self.use_loc_attn = use_loc_attn self.loc_attn_n_filters = loc_attn_n_filters self.loc_attn_kernel_size = loc_attn_kernel_size @@ -93,20 +107,23 @@ class Attention(keras.layers.Layer): self.use_forward_attn = use_forward_attn self.use_trans_agent = use_trans_agent self.use_forward_attn_mask = use_forward_attn_mask - self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer') - self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer') - self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer') + self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer") + self.inputs_layer = tf.keras.layers.Dense( + attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer" + ) + self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer") if use_loc_attn: self.location_conv1d = keras.layers.Conv1D( filters=loc_attn_n_filters, kernel_size=loc_attn_kernel_size, - padding='same', + padding="same", use_bias=False, - name='location_layer/location_conv1d') - self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense') - if norm == 'softmax': + name="location_layer/location_conv1d", + ) + self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense") + if norm == "softmax": self.norm_func = tf.nn.softmax - elif norm == 'sigmoid': + elif norm == "sigmoid": self.norm_func = _sigmoid_norm else: raise ValueError("Unknown value for attention norm type") @@ -118,30 +135,25 @@ class Attention(keras.layers.Layer): attention_old = tf.zeros([batch_size, value_length]) states = [attention_cum, attention_old] if self.use_forward_attn: - alpha = tf.concat([ - tf.ones([batch_size, 1]), - tf.zeros([batch_size, value_length])[:, :-1] + 1e-7 - ], 1) + alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1) states.append(alpha) return tuple(states) def process_values(self, values): """ cache values for decoder iterations """ - #pylint: disable=attribute-defined-outside-init + # pylint: disable=attribute-defined-outside-init self.processed_values = self.inputs_layer(values) self.values = values def get_loc_attn(self, query, states): - """ compute location attention, query layer and + """compute location attention, query layer and unnorm. attention weights""" attention_cum, attention_old = states[:2] attn_cat = tf.stack([attention_old, attention_cum], axis=2) processed_query = self.query_layer(tf.expand_dims(query, 1)) processed_attn = self.location_dense(self.location_conv1d(attn_cat)) - score = self.v( - tf.nn.tanh(self.processed_values + processed_query + - processed_attn)) + score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn)) score = tf.squeeze(score, axis=2) return score, processed_query @@ -152,14 +164,14 @@ class Attention(keras.layers.Layer): score = tf.squeeze(score, axis=2) return score, processed_query - def apply_score_masking(self, score, mask): #pylint: disable=no-self-use + def apply_score_masking(self, score, mask): # pylint: disable=no-self-use """ ignore sequence paddings """ padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2) # Bias so padding positions do not contribute to attention distribution. - score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32) + score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32) return score - def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use + def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use # forward attention fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0) # compute transition potentials @@ -206,7 +218,9 @@ class Attention(keras.layers.Layer): states = self.update_states(states, scores_norm, attn_weights, new_alpha) # context_vector shape after sum == (batch_size, hidden_size) - context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False) + context_vector = tf.matmul( + tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False + ) context_vector = tf.squeeze(context_vector, axis=1) return context_vector, attn_weights, states @@ -230,8 +244,7 @@ class Attention(keras.layers.Layer): # location_attention_filters=32, # location_attention_kernel_size=31): -# super(LocationSensitiveAttention, -# self).__init__(units=units, +# super( self).__init__(units=units, # memory=memory, # memory_sequence_length=memory_sequence_length, # normalize=normalize, diff --git a/TTS/tts/tf/layers/tacotron/tacotron2.py b/TTS/tts/tf/layers/tacotron/tacotron2.py index 094d9e4a..3247a8c4 100644 --- a/TTS/tts/tf/layers/tacotron/tacotron2.py +++ b/TTS/tts/tf/layers/tacotron/tacotron2.py @@ -5,15 +5,17 @@ from TTS.tts.tf.layers.tacotron.common_layers import Prenet, Attention # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg class ConvBNBlock(keras.layers.Layer): def __init__(self, filters, kernel_size, activation, **kwargs): - super(ConvBNBlock, self).__init__(**kwargs) - self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d') - self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization') - self.dropout = keras.layers.Dropout(rate=0.5, name='dropout') - self.activation = keras.layers.Activation(activation, name='activation') + super().__init__(**kwargs) + self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d") + self.batch_normalization = keras.layers.BatchNormalization( + axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization" + ) + self.dropout = keras.layers.Dropout(rate=0.5, name="dropout") + self.activation = keras.layers.Activation(activation, name="activation") def call(self, x, training=None): o = self.convolution1d(x) @@ -25,12 +27,12 @@ class ConvBNBlock(keras.layers.Layer): class Postnet(keras.layers.Layer): def __init__(self, output_filters, num_convs, **kwargs): - super(Postnet, self).__init__(**kwargs) + super().__init__(**kwargs) self.convolutions = [] - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0')) + self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0")) for idx in range(1, num_convs - 1): - self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}')) - self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}')) + self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}")) + self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}")) def call(self, x, training=None): o = x @@ -41,11 +43,13 @@ class Postnet(keras.layers.Layer): class Encoder(keras.layers.Layer): def __init__(self, output_input_dim, **kwargs): - super(Encoder, self).__init__(**kwargs) + super().__init__(**kwargs) self.convolutions = [] for idx in range(3): - self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}')) - self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm') + self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}")) + self.lstm = keras.layers.Bidirectional( + keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm" + ) def call(self, x, training=None): o = x @@ -56,11 +60,27 @@ class Encoder(keras.layers.Layer): class Decoder(keras.layers.Layer): - #pylint: disable=unused-argument - def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type, - prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask, - use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs): - super(Decoder, self).__init__(**kwargs) + # pylint: disable=unused-argument + def __init__( + self, + frame_dim, + r, + attn_type, + use_attn_win, + attn_norm, + prenet_type, + prenet_dropout, + use_forward_attn, + use_trans_agent, + use_forward_attn_mask, + use_location_attn, + attn_K, + separate_stopnet, + speaker_emb_dim, + enable_tflite, + **kwargs, + ): + super().__init__(**kwargs) self.frame_dim = frame_dim self.r_init = tf.constant(r, dtype=tf.int32) self.r = tf.constant(r, dtype=tf.int32) @@ -80,30 +100,31 @@ class Decoder(keras.layers.Layer): self.p_attention_dropout = 0.1 self.p_decoder_dropout = 0.1 - self.prenet = Prenet(prenet_type, - prenet_dropout, - [self.prenet_dim, self.prenet_dim], - bias=False, - name='prenet') - self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', ) + self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet") + self.attention_rnn = keras.layers.LSTMCell( + self.query_dim, + use_bias=True, + name="attention_rnn", + ) self.attention_rnn_dropout = keras.layers.Dropout(0.5) # TODO: implement other attn options - self.attention = Attention(attn_dim=self.attn_dim, - use_loc_attn=True, - loc_attn_n_filters=32, - loc_attn_kernel_size=31, - use_windowing=False, - norm=attn_norm, - use_forward_attn=use_forward_attn, - use_trans_agent=use_trans_agent, - use_forward_attn_mask=use_forward_attn_mask, - name='attention') - self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn') + self.attention = Attention( + attn_dim=self.attn_dim, + use_loc_attn=True, + loc_attn_n_filters=32, + loc_attn_kernel_size=31, + use_windowing=False, + norm=attn_norm, + use_forward_attn=use_forward_attn, + use_trans_agent=use_trans_agent, + use_forward_attn_mask=use_forward_attn_mask, + name="attention", + ) + self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn") self.decoder_rnn_dropout = keras.layers.Dropout(0.5) - self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer') - self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer') - + self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer") + self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer") def set_max_decoder_steps(self, new_max_steps): self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32) @@ -120,25 +141,31 @@ class Decoder(keras.layers.Layer): attention_states = self.attention.init_states(batch_size, memory_length) return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states - def step(self, prenet_next, states, - memory_seq_length=None, training=None): + def step(self, prenet_next, states, memory_seq_length=None, training=None): _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states attention_rnn_input = tf.concat([prenet_next, context_next], -1) - attention_rnn_output, attention_rnn_state = \ - self.attention_rnn(attention_rnn_input, - attention_rnn_state, training=training) + attention_rnn_output, attention_rnn_state = self.attention_rnn( + attention_rnn_input, attention_rnn_state, training=training + ) attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training) context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training) decoder_rnn_input = tf.concat([attention_rnn_output, context], -1) - decoder_rnn_output, decoder_rnn_state = \ - self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training) + decoder_rnn_output, decoder_rnn_state = self.decoder_rnn( + decoder_rnn_input, decoder_rnn_state, training=training + ) decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training) linear_projection_input = tf.concat([decoder_rnn_output, context], -1) output_frame = self.linear_projection(linear_projection_input, training=training) stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1) stopnet_output = self.stopnet(stopnet_input, training=training) - output_frame = output_frame[:, :self.r * self.frame_dim] - states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states) + output_frame = output_frame[:, : self.r * self.frame_dim] + states = ( + output_frame[:, self.frame_dim * (self.r - 1) :], + context, + attention_rnn_state, + decoder_rnn_state, + attention_states, + ) return output_frame, stopnet_output, states, attention def decode(self, memory, states, frames, memory_seq_length=None): @@ -157,21 +184,20 @@ class Decoder(keras.layers.Layer): def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions): prenet_next = prenet_output[:, step] - output, stop_token, states, attention = self.step(prenet_next, - states, - memory_seq_length) + output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length) outputs = outputs.write(step, output) attentions = attentions.write(step, attention) stop_tokens = stop_tokens.write(step, stop_token) return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions - _, memory, _, states, outputs, stop_tokens, attentions = \ - tf.while_loop(lambda *arg: True, - _body, - loop_vars=(step_count, memory, prenet_output, - states, outputs, stop_tokens, attentions), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=num_iter) + + _, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop( + lambda *arg: True, + _body, + loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=num_iter, + ) outputs = outputs.stack() attentions = attentions.stack() @@ -200,10 +226,7 @@ class Decoder(keras.layers.Layer): def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag): frame_next = states[0] prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, attention = self.step(prenet_next, - states, - None, - training=False) + output, stop_token, states, attention = self.step(prenet_next, states, None, training=False) stop_token = tf.math.sigmoid(stop_token) outputs = outputs.write(step, output) attentions = attentions.write(step, attention) @@ -213,14 +236,14 @@ class Decoder(keras.layers.Layer): return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - _, memory, states, outputs, stop_tokens, attentions, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_tokens, attentions, stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) + _, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop( + cond, + _body, + loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps, + ) outputs = outputs.stack() attentions = attentions.stack() @@ -238,12 +261,13 @@ class Decoder(keras.layers.Layer): batch_size is 1""" # init states # dynamic_shape is not supported in TFLite - outputs = tf.TensorArray(dtype=tf.float32, - size=self.max_decoder_steps, - element_shape=tf.TensorShape( - [self.output_dim]), - clear_after_read=False, - dynamic_size=False) + outputs = tf.TensorArray( + dtype=tf.float32, + size=self.max_decoder_steps, + element_shape=tf.TensorShape([self.output_dim]), + clear_after_read=False, + dynamic_size=False, + ) # stop_flags = tf.TensorArray(dtype=tf.bool, # size=self.max_decoder_steps, # element_shape=tf.TensorShape( @@ -263,10 +287,7 @@ class Decoder(keras.layers.Layer): def _body(step, memory, states, outputs, stop_flag): frame_next = states[0] prenet_next = self.prenet(frame_next, training=False) - output, stop_token, states, _ = self.step(prenet_next, - states, - None, - training=False) + output, stop_token, states, _ = self.step(prenet_next, states, None, training=False) stop_token = tf.math.sigmoid(stop_token) stop_flag = tf.greater(stop_token, self.stop_thresh) stop_flag = tf.reduce_all(stop_flag) @@ -276,24 +297,22 @@ class Decoder(keras.layers.Layer): return step + 1, memory, states, outputs, stop_flag cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool)) - step_count, memory, states, outputs, stop_flag = \ - tf.while_loop(cond, - _body, - loop_vars=(step_count, memory, states, outputs, - stop_flag), - parallel_iterations=32, - swap_memory=True, - maximum_iterations=self.max_decoder_steps) - + step_count, memory, states, outputs, stop_flag = tf.while_loop( + cond, + _body, + loop_vars=(step_count, memory, states, outputs, stop_flag), + parallel_iterations=32, + swap_memory=True, + maximum_iterations=self.max_decoder_steps, + ) outputs = outputs.stack() - outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter + outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter outputs = tf.expand_dims(outputs, axis=[0]) outputs = tf.transpose(outputs, [1, 0, 2]) outputs = tf.reshape(outputs, [1, -1, self.frame_dim]) return outputs, stop_tokens, attentions - def call(self, memory, states, frames=None, memory_seq_length=None, training=False): if training: return self.decode(memory, states, frames, memory_seq_length) diff --git a/TTS/tts/tf/models/tacotron2.py b/TTS/tts/tf/models/tacotron2.py index 882af517..5a0c1977 100644 --- a/TTS/tts/tf/models/tacotron2.py +++ b/TTS/tts/tf/models/tacotron2.py @@ -5,28 +5,30 @@ from TTS.tts.tf.layers.tacotron.tacotron2 import Encoder, Decoder, Postnet from TTS.tts.tf.utils.tf_utils import shape_list -#pylint: disable=too-many-ancestors, abstract-method +# pylint: disable=too-many-ancestors, abstract-method class Tacotron2(keras.models.Model): - def __init__(self, - num_chars, - num_speakers, - r, - postnet_output_dim=80, - decoder_output_dim=80, - attn_type='original', - attn_win=False, - attn_norm="softmax", - attn_K=4, - prenet_type="original", - prenet_dropout=True, - forward_attn=False, - trans_agent=False, - forward_attn_mask=False, - location_attn=True, - separate_stopnet=True, - bidirectional_decoder=False, - enable_tflite=False): - super(Tacotron2, self).__init__() + def __init__( + self, + num_chars, + num_speakers, + r, + postnet_output_dim=80, + decoder_output_dim=80, + attn_type="original", + attn_win=False, + attn_norm="softmax", + attn_K=4, + prenet_type="original", + prenet_dropout=True, + forward_attn=False, + trans_agent=False, + forward_attn_mask=False, + location_attn=True, + separate_stopnet=True, + bidirectional_decoder=False, + enable_tflite=False, + ): + super().__init__() self.r = r self.decoder_output_dim = decoder_output_dim self.postnet_output_dim = postnet_output_dim @@ -35,26 +37,28 @@ class Tacotron2(keras.models.Model): self.speaker_embed_dim = 256 self.enable_tflite = enable_tflite - self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding') - self.encoder = Encoder(512, name='encoder') + self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding") + self.encoder = Encoder(512, name="encoder") # TODO: most of the decoder args have no use at the momment - self.decoder = Decoder(decoder_output_dim, - r, - attn_type=attn_type, - use_attn_win=attn_win, - attn_norm=attn_norm, - prenet_type=prenet_type, - prenet_dropout=prenet_dropout, - use_forward_attn=forward_attn, - use_trans_agent=trans_agent, - use_forward_attn_mask=forward_attn_mask, - use_location_attn=location_attn, - attn_K=attn_K, - separate_stopnet=separate_stopnet, - speaker_emb_dim=self.speaker_embed_dim, - name='decoder', - enable_tflite=enable_tflite) - self.postnet = Postnet(postnet_output_dim, 5, name='postnet') + self.decoder = Decoder( + decoder_output_dim, + r, + attn_type=attn_type, + use_attn_win=attn_win, + attn_norm=attn_norm, + prenet_type=prenet_type, + prenet_dropout=prenet_dropout, + use_forward_attn=forward_attn, + use_trans_agent=trans_agent, + use_forward_attn_mask=forward_attn_mask, + use_location_attn=location_attn, + attn_K=attn_K, + separate_stopnet=separate_stopnet, + speaker_emb_dim=self.speaker_embed_dim, + name="decoder", + enable_tflite=enable_tflite, + ) + self.postnet = Postnet(postnet_output_dim, 5, name="postnet") @tf.function(experimental_relax_shapes=True) def call(self, characters, text_lengths=None, frames=None, training=None): @@ -62,14 +66,16 @@ class Tacotron2(keras.models.Model): return self.training(characters, text_lengths, frames) if not training: return self.inference(characters) - raise RuntimeError(' [!] Set model training mode True or False') + raise RuntimeError(" [!] Set model training mode True or False") def training(self, characters, text_lengths, frames): B, T = shape_list(characters) embedding_vectors = self.embedding(characters, training=True) encoder_output = self.encoder(embedding_vectors, training=True) decoder_states = self.decoder.build_decoder_initial_states(B, 512, T) - decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True) + decoder_frames, stop_tokens, attentions = self.decoder( + encoder_output, decoder_states, frames, text_lengths, training=True + ) postnet_frames = self.postnet(decoder_frames, training=True) output_frames = decoder_frames + postnet_frames return decoder_frames, output_frames, attentions, stop_tokens @@ -89,7 +95,8 @@ class Tacotron2(keras.models.Model): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, None], dtype=tf.int32), - ],) + ], + ) def inference_tflite(self, characters): B, T = shape_list(characters) embedding_vectors = self.embedding(characters, training=False) @@ -101,7 +108,9 @@ class Tacotron2(keras.models.Model): print(output_frames.shape) return decoder_frames, output_frames, attentions, stop_tokens - def build_inference(self, ): + def build_inference( + self, + ): # TODO: issue https://github.com/PyCQA/pylint/issues/3613 - input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) #pylint: disable=unexpected-keyword-arg + input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg self(input_ids) diff --git a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py index 03b41803..5cc072d0 100644 --- a/TTS/tts/tf/utils/convert_torch_to_tf_utils.py +++ b/TTS/tts/tf/utils/convert_torch_to_tf_utils.py @@ -2,8 +2,9 @@ import numpy as np import tensorflow as tf # NOTE: linter has a problem with the current TF release -#pylint: disable=no-value-for-parameter -#pylint: disable=unexpected-keyword-arg +# pylint: disable=no-value-for-parameter +# pylint: disable=unexpected-keyword-arg + def tf_create_dummy_inputs(): """ Create dummy inputs for TF Tacotron2 model """ @@ -13,11 +14,11 @@ def tf_create_dummy_inputs(): pad = 1 n_chars = 24 input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32) - input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size]) + input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size]) input_lengths[-1] = max_input_length input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32) mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80]) - mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size]) + mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size]) mel_lengths[-1] = max_mel_length mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32) return input_ids, input_lengths, mel_outputs, mel_lengths @@ -31,14 +32,14 @@ def compare_torch_tf(torch_tensor, tf_tensor): def convert_tf_name(tf_name): """ Convert certain patterns in TF layer names to Torch patterns """ tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(':0', '') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') - tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') - tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') - tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') - tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') - tf_name_tmp = tf_name_tmp.replace('/', '.') + tf_name_tmp = tf_name_tmp.replace(":0", "") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") + tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") + tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") + tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") + tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") + tf_name_tmp = tf_name_tmp.replace("/", ".") return tf_name_tmp @@ -47,33 +48,35 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): print(" > Passing weights from Torch to TF ...") for tf_var in tf_vars: torch_var_name = var_map_dict[tf_var.name] - print(f' | > {tf_var.name} <-- {torch_var_name}') + print(f" | > {tf_var.name} <-- {torch_var_name}") # if tuple, it is a bias variable if not isinstance(torch_var_name, tuple): - torch_layer_name = '.'.join(torch_var_name.split('.')[-2:]) + torch_layer_name = ".".join(torch_var_name.split(".")[-2:]) torch_weight = state_dict[torch_var_name] - if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name: + if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name: # out_dim, in_dim, filter -> filter, in_dim, out_dim numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy() - elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name: + elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() # if variable is for bidirectional lstm and it is a bias vector there # needs to be pre-defined two matching torch bias vectors - elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name: + elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name: bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name] assert len(bias_vectors) == 2 numpy_weight = bias_vectors[0] + bias_vectors[1] - elif 'rnn' in tf_var.name and 'kernel' in tf_var.name: + elif "rnn" in tf_var.name and "kernel" in tf_var.name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() - elif 'rnn' in tf_var.name and 'bias' in tf_var.name: + elif "rnn" in tf_var.name and "bias" in tf_var.name: bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key] assert len(bias_vectors) == 2 numpy_weight = bias_vectors[0] + bias_vectors[1] - elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name: + elif "linear_layer" in torch_layer_name and "weight" in torch_var_name: numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy() else: numpy_weight = torch_weight.detach().cpu().numpy() - assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + assert np.all( + tf_var.shape == numpy_weight.shape + ), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" tf.keras.backend.set_value(tf_var, numpy_weight) return tf_vars diff --git a/TTS/tts/tf/utils/generic_utils.py b/TTS/tts/tf/utils/generic_utils.py index 7eba946b..8956b47e 100644 --- a/TTS/tts/tf/utils/generic_utils.py +++ b/TTS/tts/tf/utils/generic_utils.py @@ -7,20 +7,20 @@ import tensorflow as tf def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): state = { - 'model': model.weights, - 'optimizer': optimizer, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r + "model": model.weights, + "optimizer": optimizer, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name @@ -32,8 +32,8 @@ def load_checkpoint(model, checkpoint_path): chkp_var_value = chkp_var_dict[layer_name] tf.keras.backend.set_value(tf_var, chkp_var_value) - if 'r' in checkpoint.keys(): - model.decoder.set_r(checkpoint['r']) + if "r" in checkpoint.keys(): + model.decoder.set_r(checkpoint["r"]) return model @@ -45,8 +45,7 @@ def sequence_mask(sequence_length, max_len=None): seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len) if sequence_length.is_cuda: seq_range_expand = seq_range_expand.cuda() - seq_length_expand = ( - sequence_length.unsqueeze(1).expand_as(seq_range_expand)) + seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand) # B x T_max return seq_range_expand < seq_length_expand @@ -62,42 +61,42 @@ def count_parameters(model, c): try: return model.count_params() except RuntimeError: - input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32')) - input_lengths = np.random.randint(100, 129, (8, )) + input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32")) + input_lengths = np.random.randint(100, 129, (8,)) input_lengths[-1] = 128 - input_lengths = tf.convert_to_tensor(input_lengths.astype('int32')) - mel_spec = np.random.rand(8, 2 * c.r, - c.audio['num_mels']).astype('float32') + input_lengths = tf.convert_to_tensor(input_lengths.astype("int32")) + mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32") mel_spec = tf.convert_to_tensor(mel_spec) - speaker_ids = np.random.randint( - 0, 5, (8, )) if c.use_speaker_embedding else None + speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids) return model.count_params() def setup_model(num_chars, num_speakers, c, enable_tflite=False): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('TTS.tts.tf.models.' + c.model.lower()) + MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower()) MyModel = getattr(MyModel, c.model) if c.model.lower() in "tacotron": - raise NotImplementedError(' [!] Tacotron model is not ready.') + raise NotImplementedError(" [!] Tacotron model is not ready.") # tacotron2 - model = MyModel(num_chars=num_chars, - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - enable_tflite=enable_tflite) + model = MyModel( + num_chars=num_chars, + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + enable_tflite=enable_tflite, + ) return model diff --git a/TTS/tts/tf/utils/io.py b/TTS/tts/tf/utils/io.py index 143422d2..06c1c9fb 100644 --- a/TTS/tts/tf/utils/io.py +++ b/TTS/tts/tf/utils/io.py @@ -5,20 +5,20 @@ import tensorflow as tf def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs): state = { - 'model': model.weights, - 'optimizer': optimizer, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r + "model": model.weights, + "optimizer": optimizer, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name @@ -30,8 +30,8 @@ def load_checkpoint(model, checkpoint_path): chkp_var_value = chkp_var_dict[layer_name] tf.keras.backend.set_value(tf_var, chkp_var_value) - if 'r' in checkpoint.keys(): - model.decoder.set_r(checkpoint['r']) + if "r" in checkpoint.keys(): + model.decoder.set_r(checkpoint["r"]) return model diff --git a/TTS/tts/tf/utils/tflite.py b/TTS/tts/tf/utils/tflite.py index b8daf254..9701d591 100644 --- a/TTS/tts/tf/utils/tflite.py +++ b/TTS/tts/tf/utils/tflite.py @@ -1,25 +1,20 @@ import tensorflow as tf -def convert_tacotron2_to_tflite(model, - output_path=None, - experimental_converter=True): +def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True): """Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is provided, else return TFLite model.""" concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function]) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) converter.experimental_new_converter = experimental_converter converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS - ] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() - print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.') + print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/TTS/tts/utils/chinese_mandarin/numbers.py b/TTS/tts/utils/chinese_mandarin/numbers.py index 94c8fd03..adb21142 100644 --- a/TTS/tts/utils/chinese_mandarin/numbers.py +++ b/TTS/tts/utils/chinese_mandarin/numbers.py @@ -1,4 +1,3 @@ - #!/usr/bin/env python3 # -*- coding: utf-8 -*- @@ -31,38 +30,37 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: # check num first nd = str(num) if abs(float(nd)) >= 1e48: - raise ValueError('number out of range') - if 'e' in nd: - raise ValueError('scientific notation is not supported') - c_symbol = '正负点' if simp else '正負點' + raise ValueError("number out of range") + if "e" in nd: + raise ValueError("scientific notation is not supported") + c_symbol = "正负点" if simp else "正負點" if o: # formal twoalt = False if big: - c_basic = '零壹贰叁肆伍陆柒捌玖' if simp else '零壹貳參肆伍陸柒捌玖' - c_unit1 = '拾佰仟' - c_twoalt = '贰' if simp else '貳' + c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖" + c_unit1 = "拾佰仟" + c_twoalt = "贰" if simp else "貳" else: - c_basic = '〇一二三四五六七八九' if o else '零一二三四五六七八九' - c_unit1 = '十百千' + c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九" + c_unit1 = "十百千" if twoalt: - c_twoalt = '两' if simp else '兩' + c_twoalt = "两" if simp else "兩" else: - c_twoalt = '二' - c_unit2 = '万亿兆京垓秭穰沟涧正载' if simp else '萬億兆京垓秭穰溝澗正載' - revuniq = lambda l: ''.join(k for k, g in itertools.groupby(reversed(l))) + c_twoalt = "二" + c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載" + revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l))) nd = str(num) result = [] - if nd[0] == '+': + if nd[0] == "+": result.append(c_symbol[0]) - elif nd[0] == '-': + elif nd[0] == "-": result.append(c_symbol[1]) - if '.' in nd: - integer, remainder = nd.lstrip('+-').split('.') + if "." in nd: + integer, remainder = nd.lstrip("+-").split(".") else: - integer, remainder = nd.lstrip('+-'), None + integer, remainder = nd.lstrip("+-"), None if int(integer): - splitted = [integer[max(i - 4, 0):i] - for i in range(len(integer), 0, -4)] + splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)] intresult = [] for nu, unit in enumerate(splitted): # special cases @@ -75,17 +73,17 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: ulist = [] unit = unit.zfill(4) for nc, ch in enumerate(reversed(unit)): - if ch == '0': + if ch == "0": if ulist: # ???0 ulist.append(c_basic[0]) elif nc == 0: ulist.append(c_basic[int(ch)]) - elif nc == 1 and ch == '1' and unit[1] == '0': + elif nc == 1 and ch == "1" and unit[1] == "0": # special case for tens # edit the 'elif' if you don't like # 十四, 三千零十四, 三千三百一十四 ulist.append(c_unit1[0]) - elif nc > 1 and ch == '2': + elif nc > 1 and ch == "2": ulist.append(c_twoalt + c_unit1[nc - 1]) else: ulist.append(c_basic[int(ch)] + c_unit1[nc - 1]) @@ -99,10 +97,8 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str: result.append(c_basic[0]) if remainder: result.append(c_symbol[2]) - result.append(''.join(c_basic[int(ch)] for ch in remainder)) - return ''.join(result) - - + result.append("".join(c_basic[int(ch)] for ch in remainder)) + return "".join(result) def _number_replace(match) -> str: @@ -127,5 +123,5 @@ def replace_numbers_to_characters_in_text(text: str) -> str: Returns: str: output text """ - text = re.sub(r'[0-9]+', _number_replace, text) + text = re.sub(r"[0-9]+", _number_replace, text) return text diff --git a/TTS/tts/utils/chinese_mandarin/phonemizer.py b/TTS/tts/utils/chinese_mandarin/phonemizer.py index 7742c491..7e46edf6 100644 --- a/TTS/tts/utils/chinese_mandarin/phonemizer.py +++ b/TTS/tts/utils/chinese_mandarin/phonemizer.py @@ -9,9 +9,7 @@ import jieba def _chinese_character_to_pinyin(text: str) -> List[str]: - pinyins = pypinyin.pinyin( - text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True - ) + pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True) pinyins_flat_list = [item for sublist in pinyins for item in sublist] return pinyins_flat_list diff --git a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py index a4722ff9..4e25c3a4 100644 --- a/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py +++ b/TTS/tts/utils/chinese_mandarin/pinyinToPhonemes.py @@ -1,4 +1,3 @@ - PINYIN_DICT = { "a": ["a"], "ai": ["ai"], diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index a75410b4..a55d3a86 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -4,8 +4,7 @@ import numpy as np def _pad_data(x, length): _pad = 0 assert x.ndim == 1 - return np.pad( - x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) def prepare_data(inputs): @@ -14,12 +13,9 @@ def prepare_data(inputs): def _pad_tensor(x, length): - _pad = 0. + _pad = 0.0 assert x.ndim == 2 - x = np.pad( - x, [[0, 0], [0, length - x.shape[1]]], - mode='constant', - constant_values=_pad) + x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad) return x @@ -31,10 +27,9 @@ def prepare_tensor(inputs, out_steps): def _pad_stop_target(x, length): - _pad = 0. + _pad = 0.0 assert x.ndim == 1 - return np.pad( - x, (0, length - x.shape[0]), mode='constant', constant_values=_pad) + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) def prepare_stop_target(inputs, out_steps): @@ -46,22 +41,18 @@ def prepare_stop_target(inputs, out_steps): def pad_per_step(inputs, pad_len): - return np.pad( - inputs, [[0, 0], [0, 0], [0, pad_len]], - mode='constant', - constant_values=0.0) + return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0) # pylint: disable=attribute-defined-outside-init -class StandardScaler(): - +class StandardScaler: def set_stats(self, mean, scale): self.mean_ = mean self.scale_ = scale def reset_stats(self): - delattr(self, 'mean_') - delattr(self, 'scale_') + delattr(self, "mean_") + delattr(self, "scale_") def transform(self, X): X = np.asarray(X) diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index 44d961ec..6e566f10 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -28,277 +28,334 @@ def split_dataset(items): return items_eval, items return items[:eval_split_size], items[eval_split_size:] + # from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1 def sequence_mask(sequence_length, max_len=None): if max_len is None: max_len = sequence_length.data.max() - seq_range = torch.arange(max_len, - dtype=sequence_length.dtype, - device=sequence_length.device) + seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device) # B x T_max return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) def to_camel(text): text = text.capitalize() - text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) - text = text.replace('Tts', 'TTS') + text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) + text = text.replace("Tts", "TTS") return text def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None): print(" > Using model: {}".format(c.model)) - MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower()) + MyModel = importlib.import_module("TTS.tts.models." + c.model.lower()) MyModel = getattr(MyModel, to_camel(c.model)) if c.model.lower() in "tacotron": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=int(c.audio['fft_size'] / 2 + 1), - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'], - memory_size=c.memory_size, - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=int(c.audio["fft_size"] / 2 + 1), + decoder_output_dim=c.audio["num_mels"], + gst=c.use_gst, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], + memory_size=c.memory_size, + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "tacotron2": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - num_speakers=num_speakers, - r=c.r, - postnet_output_dim=c.audio['num_mels'], - decoder_output_dim=c.audio['num_mels'], - gst=c.use_gst, - gst_embedding_dim=c.gst['gst_embedding_dim'], - gst_num_heads=c.gst['gst_num_heads'], - gst_style_tokens=c.gst['gst_style_tokens'], - gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'], - attn_type=c.attention_type, - attn_win=c.windowing, - attn_norm=c.attention_norm, - prenet_type=c.prenet_type, - prenet_dropout=c.prenet_dropout, - forward_attn=c.use_forward_attn, - trans_agent=c.transition_agent, - forward_attn_mask=c.forward_attn_mask, - location_attn=c.location_attn, - attn_K=c.attention_heads, - separate_stopnet=c.separate_stopnet, - bidirectional_decoder=c.bidirectional_decoder, - double_decoder_consistency=c.double_decoder_consistency, - ddc_r=c.ddc_r, - speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + num_speakers=num_speakers, + r=c.r, + postnet_output_dim=c.audio["num_mels"], + decoder_output_dim=c.audio["num_mels"], + gst=c.use_gst, + gst_embedding_dim=c.gst["gst_embedding_dim"], + gst_num_heads=c.gst["gst_num_heads"], + gst_style_tokens=c.gst["gst_style_tokens"], + gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"], + attn_type=c.attention_type, + attn_win=c.windowing, + attn_norm=c.attention_norm, + prenet_type=c.prenet_type, + prenet_dropout=c.prenet_dropout, + forward_attn=c.use_forward_attn, + trans_agent=c.transition_agent, + forward_attn_mask=c.forward_attn_mask, + location_attn=c.location_attn, + attn_K=c.attention_heads, + separate_stopnet=c.separate_stopnet, + bidirectional_decoder=c.bidirectional_decoder, + double_decoder_consistency=c.double_decoder_consistency, + ddc_r=c.ddc_r, + speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "glow_tts": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - hidden_channels_enc=c['hidden_channels_encoder'], - hidden_channels_dec=c['hidden_channels_decoder'], - hidden_channels_dp=c['hidden_channels_duration_predictor'], - out_channels=c.audio['num_mels'], - encoder_type=c.encoder_type, - encoder_params=c.encoder_params, - use_encoder_prenet=c["use_encoder_prenet"], - num_flow_blocks_dec=12, - kernel_size_dec=5, - dilation_rate=1, - num_block_layers=4, - dropout_p_dec=0.05, - num_speakers=num_speakers, - c_in_channels=0, - num_splits=4, - num_squeeze=2, - sigmoid_scale=False, - mean_only=True, - external_speaker_embedding_dim=speaker_embedding_dim) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + hidden_channels_enc=c["hidden_channels_encoder"], + hidden_channels_dec=c["hidden_channels_decoder"], + hidden_channels_dp=c["hidden_channels_duration_predictor"], + out_channels=c.audio["num_mels"], + encoder_type=c.encoder_type, + encoder_params=c.encoder_params, + use_encoder_prenet=c["use_encoder_prenet"], + num_flow_blocks_dec=12, + kernel_size_dec=5, + dilation_rate=1, + num_block_layers=4, + dropout_p_dec=0.05, + num_speakers=num_speakers, + c_in_channels=0, + num_splits=4, + num_squeeze=2, + sigmoid_scale=False, + mean_only=True, + external_speaker_embedding_dim=speaker_embedding_dim, + ) elif c.model.lower() == "speedy_speech": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio['num_mels'], - hidden_channels=c['hidden_channels'], - positional_encoding=c['positional_encoding'], - encoder_type=c['encoder_type'], - encoder_params=c['encoder_params'], - decoder_type=c['decoder_type'], - decoder_params=c['decoder_params'], - c_in_channels=0) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio["num_mels"], + hidden_channels=c["hidden_channels"], + positional_encoding=c["positional_encoding"], + encoder_type=c["encoder_type"], + encoder_params=c["encoder_params"], + decoder_type=c["decoder_type"], + decoder_params=c["decoder_params"], + c_in_channels=0, + ) elif c.model.lower() == "align_tts": - model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False), - out_channels=c.audio['num_mels'], - hidden_channels=c['hidden_channels'], - hidden_channels_dp=c['hidden_channels_dp'], - encoder_type=c['encoder_type'], - encoder_params=c['encoder_params'], - decoder_type=c['decoder_type'], - decoder_params=c['decoder_params'], - c_in_channels=0) + model = MyModel( + num_chars=num_chars + getattr(c, "add_blank", False), + out_channels=c.audio["num_mels"], + hidden_channels=c["hidden_channels"], + hidden_channels_dp=c["hidden_channels_dp"], + encoder_type=c["encoder_type"], + encoder_params=c["encoder_params"], + decoder_type=c["decoder_type"], + decoder_params=c["decoder_params"], + c_in_channels=0, + ) return model + def is_tacotron(c): - return 'tacotron' in c['model'].lower() + return "tacotron" in c["model"].lower() + def check_config_tts(c): - check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str) - check_argument('run_name', c, restricted=True, val_type=str) - check_argument('run_description', c, val_type=str) + check_argument( + "model", + c, + enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"], + restricted=True, + val_type=str, + ) + check_argument("run_name", c, restricted=True, val_type=str) + check_argument("run_description", c, val_type=str) # AUDIO - check_argument('audio', c, restricted=True, val_type=dict) + check_argument("audio", c, restricted=True, val_type=dict) # audio processing parameters - check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056) - check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058) - check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000) - check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length') - check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length') - check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1) - check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10) - check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000) - check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) - check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056) + check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058) + check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000) + check_argument( + "frame_length_ms", + c["audio"], + restricted=True, + val_type=float, + min_val=10, + max_val=1000, + alternative="win_length", + ) + check_argument( + "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length" + ) + check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1) + check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10) + check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000) + check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5) + check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000) # vocabulary parameters - check_argument('characters', c, restricted=False, val_type=dict) - check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) - check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str) - check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str) + check_argument("characters", c, restricted=False, val_type=dict) + check_argument( + "pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str + ) + check_argument( + "characters", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys(), + val_type=str, + ) + check_argument( + "phonemes", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys() and c["use_phonemes"], + val_type=str, + ) + check_argument( + "punctuations", + c["characters"] if "characters" in c.keys() else {}, + restricted="characters" in c.keys(), + val_type=str, + ) # normalization parameters - check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) - check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) - check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000) - check_argument('clip_norm', c['audio'], restricted=True, val_type=bool) - check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000) - check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0) - check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100) - check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool) - check_argument('trim_db', c['audio'], restricted=True, val_type=int) + check_argument("signal_norm", c["audio"], restricted=True, val_type=bool) + check_argument("symmetric_norm", c["audio"], restricted=True, val_type=bool) + check_argument("max_norm", c["audio"], restricted=True, val_type=float, min_val=0.1, max_val=1000) + check_argument("clip_norm", c["audio"], restricted=True, val_type=bool) + check_argument("mel_fmin", c["audio"], restricted=True, val_type=float, min_val=0.0, max_val=1000) + check_argument("mel_fmax", c["audio"], restricted=True, val_type=float, min_val=500.0) + check_argument("spec_gain", c["audio"], restricted=True, val_type=[int, float], min_val=1, max_val=100) + check_argument("do_trim_silence", c["audio"], restricted=True, val_type=bool) + check_argument("trim_db", c["audio"], restricted=True, val_type=int) # training parameters - check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) - check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) - check_argument('r', c, restricted=True, val_type=int, min_val=1) - check_argument('gradual_training', c, restricted=False, val_type=list) - check_argument('mixed_precision', c, restricted=False, val_type=bool) + check_argument("batch_size", c, restricted=True, val_type=int, min_val=1) + check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1) + check_argument("r", c, restricted=True, val_type=int, min_val=1) + check_argument("gradual_training", c, restricted=False, val_type=list) + check_argument("mixed_precision", c, restricted=False, val_type=bool) # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # loss parameters - check_argument('loss_masking', c, restricted=True, val_type=bool) - if c['model'].lower() in ['tacotron', 'tacotron2']: - check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) - if c['model'].lower in ["speedy_speech", "align_tts"]: - check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) - check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument("loss_masking", c, restricted=True, val_type=bool) + if c["model"].lower() in ["tacotron", "tacotron2"]: + check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0) + if c["model"].lower in ["speedy_speech", "align_tts"]: + check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0) + check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0) # validation parameters - check_argument('run_eval', c, restricted=True, val_type=bool) - check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) - check_argument('test_sentences_file', c, restricted=False, val_type=str) + check_argument("run_eval", c, restricted=True, val_type=bool) + check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0) + check_argument("test_sentences_file", c, restricted=False, val_type=str) # optimizer - check_argument('noam_schedule', c, restricted=False, val_type=bool) - check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0) - check_argument('epochs', c, restricted=True, val_type=int, min_val=1) - check_argument('lr', c, restricted=True, val_type=float, min_val=0) - check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0) - check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0) - check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool) + check_argument("noam_schedule", c, restricted=False, val_type=bool) + check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0) + check_argument("epochs", c, restricted=True, val_type=int, min_val=1) + check_argument("lr", c, restricted=True, val_type=float, min_val=0) + check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0) + check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0) + check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool) # tacotron prenet - check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1) - check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn']) - check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool) + check_argument("memory_size", c, restricted=is_tacotron(c), val_type=int, min_val=-1) + check_argument("prenet_type", c, restricted=is_tacotron(c), val_type=str, enum_list=["original", "bn"]) + check_argument("prenet_dropout", c, restricted=is_tacotron(c), val_type=bool) # attention - check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution']) - check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int) - check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax']) - check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool) - check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool) - check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool) - check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) - check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool) - check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool) - check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool) - check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool) - check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int) + check_argument( + "attention_type", + c, + restricted=is_tacotron(c), + val_type=str, + enum_list=["graves", "original", "dynamic_convolution"], + ) + check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int) + check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"]) + check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool) + check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool) + check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool) + check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) + check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool) + check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool) + check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool) + check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool) + check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int) - if c['model'].lower() in ['tacotron', 'tacotron2']: + if c["model"].lower() in ["tacotron", "tacotron2"]: # stopnet - check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool) - check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool) + check_argument("stopnet", c, restricted=is_tacotron(c), val_type=bool) + check_argument("separate_stopnet", c, restricted=is_tacotron(c), val_type=bool) # Model Parameters for non-tacotron models - if c['model'].lower in ["speedy_speech", "align_tts"]: - check_argument('positional_encoding', c, restricted=True, val_type=type) - check_argument('encoder_type', c, restricted=True, val_type=str) - check_argument('encoder_params', c, restricted=True, val_type=dict) - check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict) + if c["model"].lower in ["speedy_speech", "align_tts"]: + check_argument("positional_encoding", c, restricted=True, val_type=type) + check_argument("encoder_type", c, restricted=True, val_type=str) + check_argument("encoder_params", c, restricted=True, val_type=dict) + check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict) # GlowTTS parameters - check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str) + check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str) # tensorboard - check_argument('print_step', c, restricted=True, val_type=int, min_val=1) - check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) - check_argument('save_step', c, restricted=True, val_type=int, min_val=1) - check_argument('checkpoint', c, restricted=True, val_type=bool) - check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) + check_argument("print_step", c, restricted=True, val_type=int, min_val=1) + check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1) + check_argument("save_step", c, restricted=True, val_type=int, min_val=1) + check_argument("checkpoint", c, restricted=True, val_type=bool) + check_argument("tb_model_param_stats", c, restricted=True, val_type=bool) # dataloading # pylint: disable=import-outside-toplevel from TTS.tts.utils.text import cleaners - check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) - check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) - check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) - check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) - check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) - check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) - check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) - check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool) + + check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners)) + check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool) + check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0) + check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0) + check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0) + check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0) + check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10) + check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool) # paths - check_argument('output_path', c, restricted=True, val_type=str) + check_argument("output_path", c, restricted=True, val_type=str) # multi-speaker and gst - check_argument('use_speaker_embedding', c, restricted=True, val_type=bool) - check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool) - check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str) - if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']: - check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool) - check_argument('gst', c, restricted=is_tacotron(c), val_type=dict) - check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict]) - check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) - check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool) - check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) - check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) + check_argument("use_speaker_embedding", c, restricted=True, val_type=bool) + check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool) + check_argument( + "external_speaker_embedding_file", c, restricted=c["use_external_speaker_embedding_file"], val_type=str + ) + if c["model"].lower() in ["tacotron", "tacotron2"] and c["use_gst"]: + check_argument("use_gst", c, restricted=is_tacotron(c), val_type=bool) + check_argument("gst", c, restricted=is_tacotron(c), val_type=dict) + check_argument("gst_style_input", c["gst"], restricted=is_tacotron(c), val_type=[str, dict]) + check_argument("gst_embedding_dim", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000) + check_argument("gst_use_speaker_embedding", c["gst"], restricted=is_tacotron(c), val_type=bool) + check_argument("gst_num_heads", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10) + check_argument("gst_style_tokens", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000) # datasets - checking only the first entry - check_argument('datasets', c, restricted=True, val_type=list) - for dataset_entry in c['datasets']: - check_argument('name', dataset_entry, restricted=True, val_type=str) - check_argument('path', dataset_entry, restricted=True, val_type=str) - check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list]) - check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str) + check_argument("datasets", c, restricted=True, val_type=list) + for dataset_entry in c["datasets"]: + check_argument("name", dataset_entry, restricted=True, val_type=str) + check_argument("path", dataset_entry, restricted=True, val_type=str) + check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list]) + check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str) diff --git a/TTS/tts/utils/io.py b/TTS/tts/utils/io.py index bcf5ff37..32b5c766 100644 --- a/TTS/tts/utils/io.py +++ b/TTS/tts/utils/io.py @@ -6,7 +6,6 @@ import pickle as pickle_tts from TTS.utils.io import RenamingUnpickler - def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False): # pylint: disable=redefined-builtin """Load ```TTS.tts.models``` checkpoints. @@ -20,33 +19,25 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False [type]: [description] """ try: - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) - model.load_state_dict(state['model']) - if amp and 'amp' in state: - amp.load_state_dict(state['amp']) + state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) + if amp and "amp" in state: + amp.load_state_dict(state["amp"]) if use_cuda: model.cuda() # set model stepsize - if hasattr(model.decoder, 'r'): - model.decoder.set_r(state['r']) - print(" > Model r: ", state['r']) + if hasattr(model.decoder, "r"): + model.decoder.set_r(state["r"]) + print(" > Model r: ", state["r"]) if eval: model.eval() return model, state -def save_model(model, - optimizer, - current_step, - epoch, - r, - output_path, - characters, - amp_state_dict=None, - **kwargs): +def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs): """Save ```TTS.tts.models``` states with extra fields. Args: @@ -59,27 +50,26 @@ def save_model(model, characters (list): list of characters used in the model. amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None. """ - if hasattr(model, 'module'): + if hasattr(model, "module"): model_state = model.module.state_dict() else: model_state = model.state_dict() state = { - 'model': model_state, - 'optimizer': optimizer.state_dict() if optimizer is not None else None, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), - 'r': r, - 'characters': characters + "model": model_state, + "optimizer": optimizer.state_dict() if optimizer is not None else None, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), + "r": r, + "characters": characters, } if amp_state_dict: - state['amp'] = amp_state_dict + state["amp"] = amp_state_dict state.update(kwargs) torch.save(state, output_path) -def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, - characters, **kwargs): +def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs): """Save model checkpoint, intended for saving checkpoints at training. Args: @@ -91,14 +81,15 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, output_path (str): output path to save the model file. characters (list): list of characters used in the model. """ - file_name = 'checkpoint_{}.pth.tar'.format(current_step) + file_name = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(output_folder, file_name) print(" > CHECKPOINT : {}".format(checkpoint_path)) save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs) -def save_best_model(target_loss, best_loss, model, optimizer, current_step, - epoch, r, output_folder, characters, **kwargs): +def save_best_model( + target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs +): """Save model checkpoint, intended for saving the best model after each epoch. It compares the current model loss with the best loss so far and saves the model if the current loss is better. @@ -118,9 +109,11 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step, float: updated current best loss. """ if target_loss < best_loss: - file_name = 'best_model.pth.tar' + file_name = "best_model.pth.tar" checkpoint_path = os.path.join(output_folder, file_name) print(" >> BEST MODEL : {}".format(checkpoint_path)) - save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs) + save_model( + model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs + ) best_loss = target_loss return best_loss diff --git a/TTS/tts/utils/speakers.py b/TTS/tts/utils/speakers.py index feb1a845..224667dd 100644 --- a/TTS/tts/utils/speakers.py +++ b/TTS/tts/utils/speakers.py @@ -10,7 +10,7 @@ def make_speakers_json_path(out_path): def load_speaker_mapping(out_path): """Loads speaker mapping if already present.""" try: - if os.path.splitext(out_path)[1] == '.json': + if os.path.splitext(out_path)[1] == ".json": json_file = out_path else: json_file = make_speakers_json_path(out_path) @@ -19,6 +19,7 @@ def load_speaker_mapping(out_path): except FileNotFoundError: return {} + def save_speaker_mapping(out_path, speaker_mapping): """Saves speaker mapping if not yet present.""" speakers_json_path = make_speakers_json_path(out_path) @@ -31,40 +32,49 @@ def get_speakers(items): speakers = {e[2] for e in items} return sorted(speakers) + def parse_speakers(c, args, meta_data_train, OUT_PATH): """ Returns number of speakers, speaker embedding shape and speaker mapping""" if c.use_speaker_embedding: speakers = get_speakers(meta_data_train) if args.restore_path: - if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file + if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file prev_out_path = os.path.dirname(args.restore_path) speaker_mapping = load_speaker_mapping(prev_out_path) if not speaker_mapping: - print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file") + print( + "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file" + ) speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) if not speaker_mapping: - raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file") - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file + raise RuntimeError( + "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file" + ) + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"]) + elif ( + not c.use_external_speaker_embedding_file + ): # if restore checkpoint and don't use External Embedding file prev_out_path = os.path.dirname(args.restore_path) speaker_mapping = load_speaker_mapping(prev_out_path) speaker_embedding_dim = None - assert all([speaker in speaker_mapping - for speaker in speakers]), "As of now you, you cannot " \ - "introduce new speakers to " \ - "a previously trained model." - elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file + assert all(speaker in speaker_mapping for speaker in speakers), ( + "As of now you, you cannot " "introduce new speakers to " "a previously trained model." + ) + elif ( + c.use_external_speaker_embedding_file and c.external_speaker_embedding_file + ): # if start new train using External Embedding file speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file) - speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding']) - elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file + speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"]) + elif ( + c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file + ): # if start new train using External Embedding file and don't pass external embedding file raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder" - else: # if start new train and don't use External Embedding file + else: # if start new train and don't use External Embedding file speaker_mapping = {name: i for i, name in enumerate(speakers)} speaker_embedding_dim = None save_speaker_mapping(OUT_PATH, speaker_mapping) num_speakers = len(speaker_mapping) - print(" > Training with {} speakers: {}".format( - len(speakers), ", ".join(speakers))) + print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers))) else: num_speakers = 0 speaker_embedding_dim = None diff --git a/TTS/tts/utils/ssim.py b/TTS/tts/utils/ssim.py index 8f4c4cae..11107e47 100644 --- a/TTS/tts/utils/ssim.py +++ b/TTS/tts/utils/ssim.py @@ -8,8 +8,9 @@ from torch.autograd import Variable def gaussian(window_size, sigma): - gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) - return gauss/gauss.sum() + gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + def create_window(window_size, channel): _1D_window = gaussian(window_size, 1.5).unsqueeze(1) @@ -24,25 +25,22 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True): mu1_sq = mu1.pow(2) mu2_sq = mu2.pow(2) - mu1_mu2 = mu1*mu2 + mu1_mu2 = mu1 * mu2 - sigma1_sq = F.conv2d( - img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq - sigma2_sq = F.conv2d( - img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq - sigma12 = F.conv2d( - img1 * img2, window, padding=window_size // 2, - groups=channel) - mu1_mu2 + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 - C1 = 0.01**2 - C2 = 0.03**2 + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 - ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) if size_average: return ssim_map.mean() return ssim_map.mean(1).mean(1).mean(1) + class SSIM(torch.nn.Module): def __init__(self, window_size=11, size_average=True): super().__init__() @@ -66,7 +64,6 @@ class SSIM(torch.nn.Module): self.window = window self.channel = channel - return _ssim(img1, img2, window, self.window_size, channel, self.size_average) diff --git a/TTS/tts/utils/synthesis.py b/TTS/tts/utils/synthesis.py index f825d61c..4621961f 100644 --- a/TTS/tts/utils/synthesis.py +++ b/TTS/tts/utils/synthesis.py @@ -1,8 +1,10 @@ import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" import pkg_resources -installed = {pkg.key for pkg in pkg_resources.working_set} #pylint: disable=not-an-iterable -if 'tensorflow' in installed or 'tensorflow-gpu' in installed: + +installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable +if "tensorflow" in installed or "tensorflow-gpu" in installed: import tensorflow as tf import torch import numpy as np @@ -14,19 +16,26 @@ def text_to_seqvec(text, CONFIG): # text ot phonemes to sequence vector if CONFIG.use_phonemes: seq = np.asarray( - phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, - add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False), - dtype=np.int32) + phoneme_to_sequence( + text, + text_cleaner, + CONFIG.phoneme_language, + CONFIG.enable_eos_bos_chars, + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + ), + dtype=np.int32, + ) else: - seq = np.asarray(text_to_sequence( - text, - text_cleaner, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None, - add_blank=CONFIG['add_blank'] - if 'add_blank' in CONFIG.keys() else False), - dtype=np.int32) + seq = np.asarray( + text_to_sequence( + text, + text_cleaner, + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False, + ), + dtype=np.int32, + ) return seq @@ -47,86 +56,95 @@ def numpy_to_tf(np_array, dtype): def compute_style_mel(style_wav, ap, cuda=False): - style_mel = torch.FloatTensor(ap.melspectrogram( - ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) + style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0) if cuda: return style_mel.cuda() return style_mel def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None): - if 'tacotron' in CONFIG.model.lower(): + if "tacotron" in CONFIG.model.lower(): if CONFIG.use_gst: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) else: if truncated: decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) else: decoder_output, postnet_output, alignments, stop_tokens = model.inference( - inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings) - elif 'glow' in CONFIG.model.lower(): + inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings + ) + elif "glow" in CONFIG.model.lower(): inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable - if hasattr(model, 'module'): + if hasattr(model, "module"): # distributed model - postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, _, _, _, alignments, _, _ = model.module.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) else: - postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, _, _, _, alignments, _, _ = model.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) postnet_output = postnet_output.permute(0, 2, 1) # these only belong to tacotron models. decoder_output = None stop_tokens = None - elif CONFIG.model.lower() in ['speedy_speech', 'align_tts']: + elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]: inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable - if hasattr(model, 'module'): + if hasattr(model, "module"): # distributed model - postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, alignments = model.module.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) else: - postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings) + postnet_output, alignments = model.inference( + inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings + ) postnet_output = postnet_output.permute(0, 2, 1) # these only belong to tacotron models. decoder_output = None stop_tokens = None else: - raise ValueError('[!] Unknown model name.') + raise ValueError("[!] Unknown model name.") return decoder_output, postnet_output, alignments, stop_tokens def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): if CONFIG.use_gst and style_mel is not None: - raise NotImplementedError(' [!] GST inference not implemented for TF') + raise NotImplementedError(" [!] GST inference not implemented for TF") if truncated: - raise NotImplementedError(' [!] Truncated inference not implemented for TF') + raise NotImplementedError(" [!] Truncated inference not implemented for TF") if speaker_id is not None: - raise NotImplementedError(' [!] Multi-Speaker not implemented for TF') + raise NotImplementedError(" [!] Multi-Speaker not implemented for TF") # TODO: handle multispeaker case - decoder_output, postnet_output, alignments, stop_tokens = model( - inputs, training=False) + decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False) return decoder_output, postnet_output, alignments, stop_tokens def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None): if CONFIG.use_gst and style_mel is not None: - raise NotImplementedError(' [!] GST inference not implemented for TfLite') + raise NotImplementedError(" [!] GST inference not implemented for TfLite") if truncated: - raise NotImplementedError(' [!] Truncated inference not implemented for TfLite') + raise NotImplementedError(" [!] Truncated inference not implemented for TfLite") if speaker_id is not None: - raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite') + raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite") # get input and output details input_details = model.get_input_details() output_details = model.get_output_details() # reshape input tensor for the new input shape - model.resize_tensor_input(input_details[0]['index'], inputs.shape) + model.resize_tensor_input(input_details[0]["index"], inputs.shape) model.allocate_tensors() detail = input_details[0] # input_shape = detail['shape'] - model.set_tensor(detail['index'], inputs) + model.set_tensor(detail["index"], inputs) # run the model model.invoke() # collect outputs - decoder_output = model.get_tensor(output_details[0]['index']) - postnet_output = model.get_tensor(output_details[1]['index']) + decoder_output = model.get_tensor(output_details[0]["index"]) + postnet_output = model.get_tensor(output_details[1]["index"]) # tflite model only returns feature frames return decoder_output, postnet_output, None, None @@ -154,7 +172,7 @@ def parse_outputs_tflite(postnet_output, decoder_output): def trim_silence(wav, ap): - return wav[:ap.find_endpoint(wav)] + return wav[: ap.find_endpoint(wav)] def inv_spectrogram(postnet_output, ap, CONFIG): @@ -186,13 +204,13 @@ def embedding_to_torch(speaker_embedding, cuda=False): # TODO: perform GL with pytorch for batching def apply_griffin_lim(inputs, input_lens, CONFIG, ap): - '''Apply griffin-lim to each sample iterating throught the first dimension. + """Apply griffin-lim to each sample iterating throught the first dimension. Args: inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size. input_lens (Tensor or np.Array): 1D array of sample lengths. CONFIG (Dict): TTS config. ap (AudioProcessor): TTS audio processor. - ''' + """ wavs = [] for idx, spec in enumerate(inputs): wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding @@ -202,39 +220,41 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap): return wavs -def synthesis(model, - text, - CONFIG, - use_cuda, - ap, - speaker_id=None, - style_wav=None, - truncated=False, - enable_eos_bos_chars=False, #pylint: disable=unused-argument - use_griffin_lim=False, - do_trim_silence=False, - speaker_embedding=None, - backend='torch'): +def synthesis( + model, + text, + CONFIG, + use_cuda, + ap, + speaker_id=None, + style_wav=None, + truncated=False, + enable_eos_bos_chars=False, # pylint: disable=unused-argument + use_griffin_lim=False, + do_trim_silence=False, + speaker_embedding=None, + backend="torch", +): """Synthesize voice for the given text. - Args: - model (TTS.tts.models): model to synthesize. - text (str): target text - CONFIG (dict): config dictionary to be loaded from config.json. - use_cuda (bool): enable cuda. - ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process - model outputs. - speaker_id (int): id of speaker - style_wav (str | Dict[str, float]): Uses for style embedding of GST. - truncated (bool): keep model states after inference. It can be used - for continuous inference at long texts. - enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. - do_trim_silence (bool): trim silence after synthesis. - backend (str): tf or torch + Args: + model (TTS.tts.models): model to synthesize. + text (str): target text + CONFIG (dict): config dictionary to be loaded from config.json. + use_cuda (bool): enable cuda. + ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process + model outputs. + speaker_id (int): id of speaker + style_wav (str | Dict[str, float]): Uses for style embedding of GST. + truncated (bool): keep model states after inference. It can be used + for continuous inference at long texts. + enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence. + do_trim_silence (bool): trim silence after synthesis. + backend (str): tf or torch """ # GST processing style_mel = None - if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None: + if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None: if isinstance(style_wav, dict): style_mel = style_wav else: @@ -242,7 +262,7 @@ def synthesis(model, # preprocess the given text inputs = text_to_seqvec(text, CONFIG) # pass tensors to backend - if backend == 'torch': + if backend == "torch": if speaker_id is not None: speaker_id = id_to_torch(speaker_id, cuda=use_cuda) @@ -253,31 +273,35 @@ def synthesis(model, style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda) inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda) inputs = inputs.unsqueeze(0) - elif backend == 'tf': + elif backend == "tf": # TODO: handle speaker id for tf model style_mel = numpy_to_tf(style_mel, tf.float32) inputs = numpy_to_tf(inputs, tf.int32) inputs = tf.expand_dims(inputs, 0) - elif backend == 'tflite': + elif backend == "tflite": style_mel = numpy_to_tf(style_mel, tf.float32) inputs = numpy_to_tf(inputs, tf.int32) inputs = tf.expand_dims(inputs, 0) # synthesize voice - if backend == 'torch': + if backend == "torch": decoder_output, postnet_output, alignments, stop_tokens = run_model_torch( - model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding) + model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding + ) postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch( - postnet_output, decoder_output, alignments, stop_tokens) - elif backend == 'tf': + postnet_output, decoder_output, alignments, stop_tokens + ) + elif backend == "tf": decoder_output, postnet_output, alignments, stop_tokens = run_model_tf( - model, inputs, CONFIG, truncated, speaker_id, style_mel) + model, inputs, CONFIG, truncated, speaker_id, style_mel + ) postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf( - postnet_output, decoder_output, alignments, stop_tokens) - elif backend == 'tflite': + postnet_output, decoder_output, alignments, stop_tokens + ) + elif backend == "tflite": decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite( - model, inputs, CONFIG, truncated, speaker_id, style_mel) - postnet_output, decoder_output = parse_outputs_tflite( - postnet_output, decoder_output) + model, inputs, CONFIG, truncated, speaker_id, style_mel + ) + postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output) # convert outputs to numpy # plot results wav = None diff --git a/TTS/tts/utils/text/__init__.py b/TTS/tts/utils/text/__init__.py index 2a724650..6d1dc9a0 100644 --- a/TTS/tts/utils/text/__init__.py +++ b/TTS/tts/utils/text/__init__.py @@ -6,8 +6,7 @@ import phonemizer from packaging import version from phonemizer.phonemize import phonemize from TTS.tts.utils.text import cleaners -from TTS.tts.utils.text.symbols import (_bos, _eos, _punctuations, - make_symbols, phonemes, symbols) +from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols from TTS.tts.utils.chinese_mandarin.phonemizer import chinese_text_to_phonemes @@ -22,14 +21,14 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)} _symbols = symbols _phonemes = phonemes # Regular expression matching text enclosed in curly braces: -_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)') +_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)") # Regular expression matching punctuations, ignoring empty space -PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+' +PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+" def text2phone(text, language): - '''Convert graphemes to phonemes. For most of the languages, it calls + """Convert graphemes to phonemes. For most of the languages, it calls the phonemizer python library that calls espeak/espeak-ng. For chinese mandarin, it calls pypinyin + custom function for phonemizing Parameters: @@ -38,60 +37,73 @@ def text2phone(text, language): Returns: ph (str): phonemes as a string seperated by "|" ph = "ɪ|g|ˈ|z|æ|m|p|ə|l" - ''' + """ # TO REVIEW : How to have a good implementation for this? if language == "zh-CN": ph = chinese_text_to_phonemes(text) return ph - - seperator = phonemizer.separator.Separator(' |', '', '|') - #try: + seperator = phonemizer.separator.Separator(" |", "", "|") + # try: punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text) - if version.parse(phonemizer.__version__) < version.parse('2.1'): - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language) - ph = ph[:-1].strip() # skip the last empty character + if version.parse(phonemizer.__version__) < version.parse("2.1"): + ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend="espeak", language=language) + ph = ph[:-1].strip() # skip the last empty character # phonemizer does not tackle punctuations. Here we do. # Replace \n with matching punctuations. if punctuations: # if text ends with a punctuation. if text[-1] == punctuations[-1]: for punct in punctuations[:-1]: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) + ph = ph.replace("| |\n", "|" + punct + "| |", 1) ph = ph + punctuations[-1] else: for punct in punctuations: - ph = ph.replace('| |\n', '|'+punct+'| |', 1) - elif version.parse(phonemizer.__version__) >= version.parse('2.1'): - ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True, language_switch='remove-flags') + ph = ph.replace("| |\n", "|" + punct + "| |", 1) + elif version.parse(phonemizer.__version__) >= version.parse("2.1"): + ph = phonemize( + text, + separator=seperator, + strip=False, + njobs=1, + backend="espeak", + language=language, + preserve_punctuation=True, + language_switch="remove-flags", + ) # this is a simple fix for phonemizer. # https://github.com/bootphon/phonemizer/issues/32 if punctuations: for punctuation in punctuations: - ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |") + ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace( + f"| |{punctuation}", f"|{punctuation}| |" + ) ph = ph[:-3] else: raise RuntimeError(" [!] Use 'phonemizer' version 2.1 or older.") return ph + def intersperse(sequence, token): result = [token] * (len(sequence) * 2 + 1) result[1::2] = sequence return result + def pad_with_eos_bos(phoneme_sequence, tp=None): # pylint: disable=global-statement global _phonemes_to_id, _bos, _eos if tp: - _bos = tp['bos'] - _eos = tp['eos'] + _bos = tp["bos"] + _eos = tp["eos"] _, _phonemes = make_symbols(**tp) _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)} return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]] + def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False): # pylint: disable=global-statement global _phonemes_to_id, _phonemes @@ -105,23 +117,23 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp= if to_phonemes is None: print("!! After phoneme conversion the result is None. -- {} ".format(clean_text)) # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation. - for phoneme in filter(None, to_phonemes.split('|')): + for phoneme in filter(None, to_phonemes.split("|")): sequence += _phoneme_to_sequence(phoneme) # Append EOS char if enable_eos_bos: sequence = pad_with_eos_bos(sequence, tp=tp) if add_blank: - sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) + sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes) return sequence def sequence_to_phoneme(sequence, tp=None, add_blank=False): # pylint: disable=global-statement - '''Converts a sequence of IDs back to a string''' + """Converts a sequence of IDs back to a string""" global _id_to_phonemes, _phonemes if add_blank: sequence = list(filter(lambda x: x != len(_phonemes), sequence)) - result = '' + result = "" if tp: _, _phonemes = make_symbols(**tp) _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)} @@ -130,22 +142,22 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False): if symbol_id in _id_to_phonemes: s = _id_to_phonemes[symbol_id] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): - '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. - The text can optionally have ARPAbet sequences enclosed in curly braces embedded - in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." - Args: - text: string to convert to a sequence - cleaner_names: names of the cleaner functions to run the text through + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through - Returns: - List of integers corresponding to the symbols in the text - ''' + Returns: + List of integers corresponding to the symbols in the text + """ # pylint: disable=global-statement global _symbol_to_id, _symbols if tp: @@ -159,18 +171,17 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False): if not m: sequence += _symbols_to_sequence(_clean_text(text, cleaner_names)) break - sequence += _symbols_to_sequence( - _clean_text(m.group(1), cleaner_names)) + sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names)) sequence += _arpabet_to_sequence(m.group(2)) text = m.group(3) if add_blank: - sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) + sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols) return sequence def sequence_to_text(sequence, tp=None, add_blank=False): - '''Converts a sequence of IDs back to a string''' + """Converts a sequence of IDs back to a string""" # pylint: disable=global-statement global _id_to_symbol, _symbols if add_blank: @@ -180,22 +191,22 @@ def sequence_to_text(sequence, tp=None, add_blank=False): _symbols, _ = make_symbols(**tp) _id_to_symbol = {i: s for i, s in enumerate(_symbols)} - result = '' + result = "" for symbol_id in sequence: if symbol_id in _id_to_symbol: s = _id_to_symbol[symbol_id] # Enclose ARPAbet back in curly braces: - if len(s) > 1 and s[0] == '@': - s = '{%s}' % s[1:] + if len(s) > 1 and s[0] == "@": + s = "{%s}" % s[1:] result += s - return result.replace('}{', ' ') + return result.replace("}{", " ") def _clean_text(text, cleaner_names): for name in cleaner_names: cleaner = getattr(cleaners, name) if not cleaner: - raise Exception('Unknown cleaner: %s' % name) + raise Exception("Unknown cleaner: %s" % name) text = cleaner(text) return text @@ -209,12 +220,12 @@ def _phoneme_to_sequence(phons): def _arpabet_to_sequence(text): - return _symbols_to_sequence(['@' + s for s in text.split()]) + return _symbols_to_sequence(["@" + s for s in text.split()]) def _should_keep_symbol(s): - return s in _symbol_to_id and s not in ['~', '^', '_'] + return s in _symbol_to_id and s not in ["~", "^", "_"] def _should_keep_phoneme(p): - return p in _phonemes_to_id and p not in ['~', '^', '_'] + return p in _phonemes_to_id and p not in ["~", "^", "_"] diff --git a/TTS/tts/utils/text/abbreviations.py b/TTS/tts/utils/text/abbreviations.py index 579d7dcd..7e44b90c 100644 --- a/TTS/tts/utils/text/abbreviations.py +++ b/TTS/tts/utils/text/abbreviations.py @@ -1,66 +1,73 @@ import re # List of (regular expression, replacement) pairs for abbreviations in english: -abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) - for x in [ - ('mrs', 'misess'), - ('mr', 'mister'), - ('dr', 'doctor'), - ('st', 'saint'), - ('co', 'company'), - ('jr', 'junior'), - ('maj', 'major'), - ('gen', 'general'), - ('drs', 'doctors'), - ('rev', 'reverend'), - ('lt', 'lieutenant'), - ('hon', 'honorable'), - ('sgt', 'sergeant'), - ('capt', 'captain'), - ('esq', 'esquire'), - ('ltd', 'limited'), - ('col', 'colonel'), - ('ft', 'fort'), - ]] +abbreviations_en = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("mrs", "misess"), + ("mr", "mister"), + ("dr", "doctor"), + ("st", "saint"), + ("co", "company"), + ("jr", "junior"), + ("maj", "major"), + ("gen", "general"), + ("drs", "doctors"), + ("rev", "reverend"), + ("lt", "lieutenant"), + ("hon", "honorable"), + ("sgt", "sergeant"), + ("capt", "captain"), + ("esq", "esquire"), + ("ltd", "limited"), + ("col", "colonel"), + ("ft", "fort"), + ] +] # List of (regular expression, replacement) pairs for abbreviations in french: -abbreviations_fr = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) - for x in [ - ('M', 'monsieur'), - ('Mlle', 'mademoiselle'), - ('Mlles', 'mesdemoiselles'), - ('Mme', 'Madame'), - ('Mmes', 'Mesdames'), - ('N.B', 'nota bene'), - ('M', 'monsieur'), - ('p.c.q', 'parce que'), - ('Pr', 'professeur'), - ('qqch', 'quelque chose'), - ('rdv', 'rendez-vous'), - ('max', 'maximum'), - ('min', 'minimum'), - ('no', 'numéro'), - ('adr', 'adresse'), - ('dr', 'docteur'), - ('st', 'saint'), - ('co', 'companie'), - ('jr', 'junior'), - ('sgt', 'sergent'), - ('capt', 'capitain'), - ('col', 'colonel'), - ('av', 'avenue'), - ('av. J.-C', 'avant Jésus-Christ'), - ('apr. J.-C', 'après Jésus-Christ'), - ('art', 'article'), - ('boul', 'boulevard'), - ('c.-à-d', 'c’est-à-dire'), - ('etc', 'et cetera'), - ('ex', 'exemple'), - ('excl', 'exclusivement'), - ('boul', 'boulevard'), - ]] + [(re.compile('\\b%s' % x[0]), x[1]) for x in [ - ('Mlle', 'mademoiselle'), - ('Mlles', 'mesdemoiselles'), - ('Mme', 'Madame'), - ('Mmes', 'Mesdames'), - ]] +abbreviations_fr = [ + (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) + for x in [ + ("M", "monsieur"), + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ("N.B", "nota bene"), + ("M", "monsieur"), + ("p.c.q", "parce que"), + ("Pr", "professeur"), + ("qqch", "quelque chose"), + ("rdv", "rendez-vous"), + ("max", "maximum"), + ("min", "minimum"), + ("no", "numéro"), + ("adr", "adresse"), + ("dr", "docteur"), + ("st", "saint"), + ("co", "companie"), + ("jr", "junior"), + ("sgt", "sergent"), + ("capt", "capitain"), + ("col", "colonel"), + ("av", "avenue"), + ("av. J.-C", "avant Jésus-Christ"), + ("apr. J.-C", "après Jésus-Christ"), + ("art", "article"), + ("boul", "boulevard"), + ("c.-à-d", "c’est-à-dire"), + ("etc", "et cetera"), + ("ex", "exemple"), + ("excl", "exclusivement"), + ("boul", "boulevard"), + ] +] + [ + (re.compile("\\b%s" % x[0]), x[1]) + for x in [ + ("Mlle", "mademoiselle"), + ("Mlles", "mesdemoiselles"), + ("Mme", "Madame"), + ("Mmes", "Mesdames"), + ] +] diff --git a/TTS/tts/utils/text/cleaners.py b/TTS/tts/utils/text/cleaners.py index 4e1c6d43..555d01d1 100644 --- a/TTS/tts/utils/text/cleaners.py +++ b/TTS/tts/utils/text/cleaners.py @@ -1,4 +1,4 @@ -''' +""" Cleaners are transformations that run over the input text at both training and eval time. Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" @@ -8,7 +8,7 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use the Unidecode library (https://pypi.python.org/pypi/Unidecode) 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update the symbols in symbols.py to match your data). -''' +""" import re from unidecode import unidecode @@ -19,13 +19,13 @@ from TTS.tts.utils.chinese_mandarin.numbers import replace_numbers_to_characters # Regular expression matching whitespace: -_whitespace_re = re.compile(r'\s+') +_whitespace_re = re.compile(r"\s+") -def expand_abbreviations(text, lang='en'): - if lang == 'en': +def expand_abbreviations(text, lang="en"): + if lang == "en": _abbreviations = abbreviations_en - elif lang == 'fr': + elif lang == "fr": _abbreviations = abbreviations_fr for regex, replacement in _abbreviations: text = re.sub(regex, replacement, text) @@ -41,7 +41,7 @@ def lowercase(text): def collapse_whitespace(text): - return re.sub(_whitespace_re, ' ', text).strip() + return re.sub(_whitespace_re, " ", text).strip() def convert_to_ascii(text): @@ -49,30 +49,32 @@ def convert_to_ascii(text): def remove_aux_symbols(text): - text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text) + text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text) return text -def replace_symbols(text, lang='en'): - text = text.replace(';', ',') - text = text.replace('-', ' ') - text = text.replace(':', ',') - if lang == 'en': - text = text.replace('&', ' and ') - elif lang == 'fr': - text = text.replace('&', ' et ') - elif lang == 'pt': - text = text.replace('&', ' e ') + +def replace_symbols(text, lang="en"): + text = text.replace(";", ",") + text = text.replace("-", " ") + text = text.replace(":", ",") + if lang == "en": + text = text.replace("&", " and ") + elif lang == "fr": + text = text.replace("&", " et ") + elif lang == "pt": + text = text.replace("&", " e ") return text + def basic_cleaners(text): - '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + """Basic pipeline that lowercases and collapses whitespace without transliteration.""" text = lowercase(text) text = collapse_whitespace(text) return text def transliteration_cleaners(text): - '''Pipeline for non-English text that transliterates to ASCII.''' + """Pipeline for non-English text that transliterates to ASCII.""" text = convert_to_ascii(text) text = lowercase(text) text = collapse_whitespace(text) @@ -80,7 +82,7 @@ def transliteration_cleaners(text): def basic_german_cleaners(text): - '''Pipeline for German text''' + """Pipeline for German text""" text = lowercase(text) text = collapse_whitespace(text) return text @@ -88,7 +90,7 @@ def basic_german_cleaners(text): # TODO: elaborate it def basic_turkish_cleaners(text): - '''Pipeline for Turkish text''' + """Pipeline for Turkish text""" text = text.replace("I", "ı") text = lowercase(text) text = collapse_whitespace(text) @@ -96,7 +98,7 @@ def basic_turkish_cleaners(text): def english_cleaners(text): - '''Pipeline for English text, including number and abbreviation expansion.''' + """Pipeline for English text, including number and abbreviation expansion.""" text = convert_to_ascii(text) text = lowercase(text) text = expand_time_english(text) @@ -109,33 +111,33 @@ def english_cleaners(text): def french_cleaners(text): - '''Pipeline for French text. There is no need to expand numbers, phonemizer already does that''' - text = expand_abbreviations(text, lang='fr') + """Pipeline for French text. There is no need to expand numbers, phonemizer already does that""" + text = expand_abbreviations(text, lang="fr") text = lowercase(text) - text = replace_symbols(text, lang='fr') + text = replace_symbols(text, lang="fr") text = remove_aux_symbols(text) text = collapse_whitespace(text) return text def portuguese_cleaners(text): - '''Basic pipeline for Portuguese text. There is no need to expand abbreviation and - numbers, phonemizer already does that''' + """Basic pipeline for Portuguese text. There is no need to expand abbreviation and + numbers, phonemizer already does that""" text = lowercase(text) - text = replace_symbols(text, lang='pt') + text = replace_symbols(text, lang="pt") text = remove_aux_symbols(text) text = collapse_whitespace(text) return text def chinese_mandarin_cleaners(text: str) -> str: - '''Basic pipeline for chinese''' + """Basic pipeline for chinese""" text = replace_numbers_to_characters_in_text(text) return text def phoneme_cleaners(text): - '''Pipeline for phonemes mode, including number and abbreviation expansion.''' + """Pipeline for phonemes mode, including number and abbreviation expansion.""" text = expand_numbers(text) text = convert_to_ascii(text) text = expand_abbreviations(text) diff --git a/TTS/tts/utils/text/cmudict.py b/TTS/tts/utils/text/cmudict.py index c0f23406..f206fb04 100644 --- a/TTS/tts/utils/text/cmudict.py +++ b/TTS/tts/utils/text/cmudict.py @@ -3,43 +3,116 @@ import re VALID_SYMBOLS = [ - 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', - 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', - 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', - 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', - 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', - 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', - 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', - 'Y', 'Z', 'ZH' + "AA", + "AA0", + "AA1", + "AA2", + "AE", + "AE0", + "AE1", + "AE2", + "AH", + "AH0", + "AH1", + "AH2", + "AO", + "AO0", + "AO1", + "AO2", + "AW", + "AW0", + "AW1", + "AW2", + "AY", + "AY0", + "AY1", + "AY2", + "B", + "CH", + "D", + "DH", + "EH", + "EH0", + "EH1", + "EH2", + "ER", + "ER0", + "ER1", + "ER2", + "EY", + "EY0", + "EY1", + "EY2", + "F", + "G", + "HH", + "IH", + "IH0", + "IH1", + "IH2", + "IY", + "IY0", + "IY1", + "IY2", + "JH", + "K", + "L", + "M", + "N", + "NG", + "OW", + "OW0", + "OW1", + "OW2", + "OY", + "OY0", + "OY1", + "OY2", + "P", + "R", + "S", + "SH", + "T", + "TH", + "UH", + "UH0", + "UH1", + "UH2", + "UW", + "UW0", + "UW1", + "UW2", + "V", + "W", + "Y", + "Z", + "ZH", ] class CMUDict: - '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict""" def __init__(self, file_or_path, keep_ambiguous=True): if isinstance(file_or_path, str): - with open(file_or_path, encoding='latin-1') as f: + with open(file_or_path, encoding="latin-1") as f: entries = _parse_cmudict(f) else: entries = _parse_cmudict(file_or_path) if not keep_ambiguous: - entries = { - word: pron - for word, pron in entries.items() if len(pron) == 1 - } + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} self._entries = entries def __len__(self): return len(self._entries) def lookup(self, word): - '''Returns list of ARPAbet pronunciations of the given word.''' + """Returns list of ARPAbet pronunciations of the given word.""" return self._entries.get(word.upper()) @staticmethod def get_arpabet(word, cmudict, punctuation_symbols): - first_symbol, last_symbol = '', '' + first_symbol, last_symbol = "", "" if word and word[0] in punctuation_symbols: first_symbol = word[0] word = word[1:] @@ -48,19 +121,19 @@ class CMUDict: word = word[:-1] arpabet = cmudict.lookup(word) if arpabet is not None: - return first_symbol + '{%s}' % arpabet[0] + last_symbol + return first_symbol + "{%s}" % arpabet[0] + last_symbol return first_symbol + word + last_symbol -_alt_re = re.compile(r'\([0-9]+\)') +_alt_re = re.compile(r"\([0-9]+\)") def _parse_cmudict(file): cmudict = {} for line in file: - if line and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): - parts = line.split(' ') - word = re.sub(_alt_re, '', parts[0]) + if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"): + parts = line.split(" ") + word = re.sub(_alt_re, "", parts[0]) pronunciation = _get_pronunciation(parts[1]) if pronunciation: if word in cmudict: @@ -71,8 +144,8 @@ def _parse_cmudict(file): def _get_pronunciation(s): - parts = s.strip().split(' ') + parts = s.strip().split(" ") for part in parts: if part not in VALID_SYMBOLS: return None - return ' '.join(parts) + return " ".join(parts) diff --git a/TTS/tts/utils/text/number_norm.py b/TTS/tts/utils/text/number_norm.py index 2b83c271..4f648b42 100644 --- a/TTS/tts/utils/text/number_norm.py +++ b/TTS/tts/utils/text/number_norm.py @@ -5,23 +5,23 @@ import re from typing import Dict _inflect = inflect.engine() -_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') -_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') -_currency_re = re.compile(r'(£|\$|¥)([0-9\,\.]*[0-9]+)') -_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') -_number_re = re.compile(r'-?[0-9]+') +_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") +_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") +_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)") +_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") +_number_re = re.compile(r"-?[0-9]+") def _remove_commas(m): - return m.group(1).replace(',', '') + return m.group(1).replace(",", "") def _expand_decimal_point(m): - return m.group(1).replace('.', ' point ') + return m.group(1).replace(".", " point ") def __expand_currency(value: str, inflection: Dict[float, str]) -> str: - parts = value.replace(",", "").split('.') + parts = value.replace(",", "").split(".") if len(parts) > 2: return f"{value} {inflection[2]}" # Unexpected format text = [] @@ -31,7 +31,7 @@ def __expand_currency(value: str, inflection: Dict[float, str]) -> str: text.append(f"{integer} {integer_unit}") fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0 if fraction > 0: - fraction_unit = inflection.get(fraction/100, inflection[0.02]) + fraction_unit = inflection.get(fraction / 100, inflection[0.02]) text.append(f"{fraction} {fraction_unit}") if len(text) == 0: return f"zero {inflection[2]}" @@ -62,7 +62,7 @@ def _expand_currency(m: "re.Match") -> str: # TODO rin 0.02: "sen", 2: "yen", - } + }, } unit = m.group(1) currency = currencies[unit] @@ -78,16 +78,13 @@ def _expand_number(m): num = int(m.group(0)) if 1000 < num < 3000: if num == 2000: - return 'two thousand' + return "two thousand" if 2000 < num < 2010: - return 'two thousand ' + _inflect.number_to_words(num % 100) + return "two thousand " + _inflect.number_to_words(num % 100) if num % 100 == 0: - return _inflect.number_to_words(num // 100) + ' hundred' - return _inflect.number_to_words(num, - andword='', - zero='oh', - group=2).replace(', ', ' ') - return _inflect.number_to_words(num, andword='') + return _inflect.number_to_words(num // 100) + " hundred" + return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") + return _inflect.number_to_words(num, andword="") def normalize_numbers(text): diff --git a/TTS/tts/utils/text/symbols.py b/TTS/tts/utils/text/symbols.py index 83435917..a531849d 100644 --- a/TTS/tts/utils/text/symbols.py +++ b/TTS/tts/utils/text/symbols.py @@ -1,38 +1,41 @@ # -*- coding: utf-8 -*- -''' +""" Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. -''' +""" -def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name - ''' Function to create symbols and phonemes ''' +def make_symbols( + characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^" +): # pylint: disable=redefined-outer-name + """ Function to create symbols and phonemes """ _symbols = [pad, eos, bos] + list(characters) _phonemes = None if phonemes is not None: _phonemes_sorted = sorted(list(set(phonemes))) # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): - _arpabet = ['@' + s for s in _phonemes_sorted] + _arpabet = ["@" + s for s in _phonemes_sorted] # Export all symbols: _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations) _symbols += _arpabet return _symbols, _phonemes -_pad = '_' -_eos = '~' -_bos = '^' -_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? ' -_punctuations = '!\'(),-.:;? ' + +_pad = "_" +_eos = "~" +_bos = "^" +_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? " +_punctuations = "!'(),-.:;? " # Phonemes definition (All IPA characters) -_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ' -_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ' -_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ' -_suprasegmentals = 'ˈˌːˑ' -_other_symbols = 'ʍwɥʜʢʡɕʑɺɧʲ' -_diacrilics = 'ɚ˞ɫ' +_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ" +_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ" +_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ" +_suprasegmentals = "ˈˌːˑ" +_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ" +_diacrilics = "ɚ˞ɫ" _phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos) @@ -43,16 +46,18 @@ symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _e def parse_symbols(): - return {'pad': _pad, - 'eos': _eos, - 'bos': _bos, - 'characters': _characters, - 'punctuations': _punctuations, - 'phonemes': _phonemes} + return { + "pad": _pad, + "eos": _eos, + "bos": _bos, + "characters": _characters, + "punctuations": _punctuations, + "phonemes": _phonemes, + } -if __name__ == '__main__': +if __name__ == "__main__": print(" > TTS symbols {}".format(len(symbols))) print(symbols) print(" > TTS phonemes {}".format(len(phonemes))) - print(''.join(sorted(phonemes))) + print("".join(sorted(phonemes))) diff --git a/TTS/tts/utils/text/time.py b/TTS/tts/utils/text/time.py index 55ecbd8c..a591434f 100644 --- a/TTS/tts/utils/text/time.py +++ b/TTS/tts/utils/text/time.py @@ -3,13 +3,15 @@ import inflect _inflect = inflect.engine() -_time_re = re.compile(r"""\b +_time_re = re.compile( + r"""\b ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours : ([0-5][0-9]) # minutes \s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)? # am/pm \b""", - re.IGNORECASE | re.X) + re.IGNORECASE | re.X, +) def _expand_num(n: int) -> str: diff --git a/TTS/tts/utils/visual.py b/TTS/tts/utils/visual.py index e5bb5891..97a8cd48 100644 --- a/TTS/tts/utils/visual.py +++ b/TTS/tts/utils/visual.py @@ -3,33 +3,25 @@ import matplotlib import numpy as np import torch -matplotlib.use('Agg') +matplotlib.use("Agg") import matplotlib.pyplot as plt from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme -def plot_alignment(alignment, - info=None, - fig_size=(16, 10), - title=None, - output_fig=False): +def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False): if isinstance(alignment, torch.Tensor): alignment_ = alignment.detach().cpu().numpy().squeeze() else: alignment_ = alignment - alignment_ = alignment_.astype( - np.float32) if alignment_.dtype == np.float16 else alignment_ + alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_ fig, ax = plt.subplots(figsize=fig_size) - im = ax.imshow(alignment_.T, - aspect='auto', - origin='lower', - interpolation='none') + im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none") fig.colorbar(im, ax=ax) - xlabel = 'Decoder timestep' + xlabel = "Decoder timestep" if info is not None: - xlabel += '\n\n' + info + xlabel += "\n\n" + info plt.xlabel(xlabel) - plt.ylabel('Encoder timestep') + plt.ylabel("Encoder timestep") # plt.yticks(range(len(text)), list(text)) plt.tight_layout() if title is not None: @@ -39,16 +31,12 @@ def plot_alignment(alignment, return fig -def plot_spectrogram(spectrogram, - ap=None, - fig_size=(16, 10), - output_fig=False): +def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False): if isinstance(spectrogram, torch.Tensor): spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T else: spectrogram_ = spectrogram.T - spectrogram_ = spectrogram_.astype( - np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ + spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_ if ap is not None: spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access fig = plt.figure(figsize=fig_size) @@ -60,16 +48,18 @@ def plot_spectrogram(spectrogram, return fig -def visualize(alignment, - postnet_output, - text, - hop_length, - CONFIG, - stop_tokens=None, - decoder_output=None, - output_path=None, - figsize=(8, 24), - output_fig=False): +def visualize( + alignment, + postnet_output, + text, + hop_length, + CONFIG, + stop_tokens=None, + decoder_output=None, + output_path=None, + figsize=(8, 24), + output_fig=False, +): if decoder_output is not None: num_plot = 4 @@ -86,13 +76,13 @@ def visualize(alignment, # compute phoneme representation and back if CONFIG.use_phonemes: seq = phoneme_to_sequence( - text, [CONFIG.text_cleaner], + text, + [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) - text = sequence_to_phoneme( - seq, - tp=CONFIG.characters if 'characters' in CONFIG.keys() else None) + tp=CONFIG.characters if "characters" in CONFIG.keys() else None, + ) + text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None) print(text) plt.yticks(range(len(text)), list(text)) plt.colorbar() @@ -104,13 +94,15 @@ def visualize(alignment, # plot postnet spectrogram plt.subplot(num_plot, 1, 3) - librosa.display.specshow(postnet_output.T, - sr=CONFIG.audio['sample_rate'], - hop_length=hop_length, - x_axis="time", - y_axis="linear", - fmin=CONFIG.audio['mel_fmin'], - fmax=CONFIG.audio['mel_fmax']) + librosa.display.specshow( + postnet_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) plt.xlabel("Time", fontsize=label_fontsize) plt.ylabel("Hz", fontsize=label_fontsize) @@ -119,13 +111,15 @@ def visualize(alignment, if decoder_output is not None: plt.subplot(num_plot, 1, 4) - librosa.display.specshow(decoder_output.T, - sr=CONFIG.audio['sample_rate'], - hop_length=hop_length, - x_axis="time", - y_axis="linear", - fmin=CONFIG.audio['mel_fmin'], - fmax=CONFIG.audio['mel_fmax']) + librosa.display.specshow( + decoder_output.T, + sr=CONFIG.audio["sample_rate"], + hop_length=hop_length, + x_axis="time", + y_axis="linear", + fmin=CONFIG.audio["mel_fmin"], + fmax=CONFIG.audio["mel_fmax"], + ) plt.xlabel("Time", fontsize=label_fontsize) plt.ylabel("Hz", fontsize=label_fontsize) plt.tight_layout() diff --git a/TTS/utils/arguments.py b/TTS/utils/arguments.py index 17a76aa6..c688cd16 100644 --- a/TTS/utils/arguments.py +++ b/TTS/utils/arguments.py @@ -29,41 +29,31 @@ def parse_arguments(argv): parser.add_argument( "--continue_path", type=str, - help=("Training output folder to continue training. Used to continue " - "a training. If it is used, 'config_path' is ignored."), + help=( + "Training output folder to continue training. Used to continue " + "a training. If it is used, 'config_path' is ignored." + ), default="", - required="--config_path" not in argv) + required="--config_path" not in argv, + ) parser.add_argument( - "--restore_path", - type=str, - help="Model file to be restored. Use to finetune a model.", - default="") + "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default="" + ) parser.add_argument( "--best_path", type=str, - help=("Best model file to be used for extracting best loss." - "If not specified, the latest best model in continue path is used"), - default="") - parser.add_argument( - "--config_path", - type=str, - help="Path to config file for training.", - required="--continue_path" not in argv) - parser.add_argument( - "--debug", - type=bool, - default=False, - help="Do not verify commit integrity to run training.") - parser.add_argument( - "--rank", - type=int, - default=0, - help="DISTRIBUTED: process rank for distributed training.") - parser.add_argument( - "--group_id", - type=str, + help=( + "Best model file to be used for extracting best loss." + "If not specified, the latest best model in continue path is used" + ), default="", - help="DISTRIBUTED: process group id.") + ) + parser.add_argument( + "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv + ) + parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.") + parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.") + parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.") return parser.parse_args() @@ -86,7 +76,7 @@ def get_last_checkpoint(path): file_names = glob.glob(os.path.join(path, "*.pth.tar")) last_models = {} last_model_nums = {} - for key in ['checkpoint', 'best_model']: + for key in ["checkpoint", "best_model"]: last_model_num = None last_model = None # pass all the checkpoint files and find @@ -105,7 +95,7 @@ def get_last_checkpoint(path): key_file_names = [fn for fn in file_names if key in fn] if last_model is None and len(key_file_names) > 0: last_model = max(key_file_names, key=os.path.getctime) - last_model_num = torch.load(last_model)['step'] + last_model_num = torch.load(last_model)["step"] if last_model is not None: last_models[key] = last_model @@ -114,16 +104,16 @@ def get_last_checkpoint(path): # check what models were found if not last_models: raise ValueError(f"No models found in continue path {path}!") - if 'checkpoint' not in last_models: # no checkpoint just best model - last_models['checkpoint'] = last_models['best_model'] - elif 'best_model' not in last_models: # no best model + if "checkpoint" not in last_models: # no checkpoint just best model + last_models["checkpoint"] = last_models["best_model"] + elif "best_model" not in last_models: # no best model # this shouldn't happen, but let's handle it just in case - last_models['best_model'] = None + last_models["best_model"] = None # finally check if last best model is more recent than checkpoint - elif last_model_nums['best_model'] > last_model_nums['checkpoint']: - last_models['checkpoint'] = last_models['best_model'] + elif last_model_nums["best_model"] > last_model_nums["checkpoint"]: + last_models["checkpoint"] = last_models["best_model"] - return last_models['checkpoint'], last_models['best_model'] + return last_models["checkpoint"], last_models["best_model"] def process_args(args, model_class): @@ -157,13 +147,12 @@ def process_args(args, model_class): c = load_config(args.config_path) _ = os.path.dirname(os.path.realpath(__file__)) - if 'mixed_precision' in c and c.mixed_precision: + if "mixed_precision" in c and c.mixed_precision: print(" > Mixed precision mode is ON") out_path = args.continue_path if not out_path: - out_path = create_experiment_folder(c.output_path, c.run_name, - args.debug) + out_path = create_experiment_folder(c.output_path, c.run_name, args.debug) audio_path = os.path.join(out_path, "test_audios") @@ -179,11 +168,10 @@ def process_args(args, model_class): # if model characters are not set in the config file # save the default set to the config file for future # compatibility. - if model_class == 'tts' and 'characters' not in c: + if model_class == "tts" and "characters" not in c: used_characters = parse_symbols() - new_fields['characters'] = used_characters - copy_model_files(c, args.config_path, - out_path, new_fields) + new_fields["characters"] = used_characters + copy_model_files(c, args.config_path, out_path, new_fields) os.chmod(audio_path, 0o775) os.chmod(out_path, 0o775) diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py index f89d1ee5..2c451d23 100644 --- a/TTS/utils/audio.py +++ b/TTS/utils/audio.py @@ -3,11 +3,12 @@ import soundfile as sf import numpy as np import scipy.io.wavfile import scipy.signal + # import pyworld as pw from TTS.tts.utils.data import StandardScaler -#pylint: disable=too-many-public-methods +# pylint: disable=too-many-public-methods class AudioProcessor(object): """Audio Processor for TTS used by all the data pipelines. @@ -43,35 +44,38 @@ class AudioProcessor(object): stats_path (str, optional): Path to the computed stats file. Defaults to None. verbose (bool, optional): enable/disable logging. Defaults to True. """ - def __init__(self, - sample_rate=None, - resample=False, - num_mels=None, - log_func='np.log10', - min_level_db=None, - frame_shift_ms=None, - frame_length_ms=None, - hop_length=None, - win_length=None, - ref_level_db=None, - fft_size=1024, - power=None, - preemphasis=0.0, - signal_norm=None, - symmetric_norm=None, - max_norm=None, - mel_fmin=None, - mel_fmax=None, - spec_gain=20, - stft_pad_mode='reflect', - clip_norm=True, - griffin_lim_iters=None, - do_trim_silence=False, - trim_db=60, - do_sound_norm=False, - stats_path=None, - verbose=True, - **_): + + def __init__( + self, + sample_rate=None, + resample=False, + num_mels=None, + log_func="np.log10", + min_level_db=None, + frame_shift_ms=None, + frame_length_ms=None, + hop_length=None, + win_length=None, + ref_level_db=None, + fft_size=1024, + power=None, + preemphasis=0.0, + signal_norm=None, + symmetric_norm=None, + max_norm=None, + mel_fmin=None, + mel_fmax=None, + spec_gain=20, + stft_pad_mode="reflect", + clip_norm=True, + griffin_lim_iters=None, + do_trim_silence=False, + trim_db=60, + do_sound_norm=False, + stats_path=None, + verbose=True, + **_, + ): # setup class attributed self.sample_rate = sample_rate @@ -98,14 +102,14 @@ class AudioProcessor(object): self.do_sound_norm = do_sound_norm self.stats_path = stats_path # setup exp_func for db to amp conversion - print(f'self.log_func = {log_func}') - exec(f'self.log_func = {log_func}') #pylint: disable=exec-used - if self.log_func.__name__ == 'log': + print(f"self.log_func = {log_func}") + exec(f"self.log_func = {log_func}") # pylint: disable=exec-used + if self.log_func.__name__ == "log": self.exp_func = np.exp - elif self.log_func.__name__ == 'log10': + elif self.log_func.__name__ == "log10": self.exp_func = lambda x: 10 ** x else: - raise ValueError(' [!] unknown `log_func` value.') + raise ValueError(" [!] unknown `log_func` value.") # setup stft parameters if hop_length is None: # compute stft parameters from given time values @@ -134,17 +138,18 @@ class AudioProcessor(object): self.symmetric_norm = None ### setting up the parameters ### - def _build_mel_basis(self, ): + def _build_mel_basis( + self, + ): if self.mel_fmax is not None: assert self.mel_fmax <= self.sample_rate // 2 return librosa.filters.mel( - self.sample_rate, - self.fft_size, - n_mels=self.num_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax) + self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) - def _stft_parameters(self, ): + def _stft_parameters( + self, + ): """Compute necessary stft parameters with given time values""" factor = self.frame_length_ms / self.frame_shift_ms assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" @@ -155,24 +160,26 @@ class AudioProcessor(object): ### normalization ### def normalize(self, S): """Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]""" - #pylint: disable=no-else-return + # pylint: disable=no-else-return S = S.copy() if self.signal_norm: # mean-var scaling - if hasattr(self, 'mel_scaler'): + if hasattr(self, "mel_scaler"): if S.shape[0] == self.num_mels: return self.mel_scaler.transform(S.T).T elif S.shape[0] == self.fft_size / 2: return self.linear_scaler.transform(S.T).T else: - raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.') + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") # range normalization S -= self.ref_level_db # discard certain range of DB assuming it is air noise - S_norm = ((S - self.min_level_db) / (-self.min_level_db)) + S_norm = (S - self.min_level_db) / (-self.min_level_db) if self.symmetric_norm: S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm if self.clip_norm: - S_norm = np.clip(S_norm, -self.max_norm, self.max_norm) # pylint: disable=invalid-unary-operand-type + S_norm = np.clip( + S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) return S_norm else: S_norm = self.max_norm * S_norm @@ -184,47 +191,49 @@ class AudioProcessor(object): def denormalize(self, S): """denormalize values""" - #pylint: disable=no-else-return + # pylint: disable=no-else-return S_denorm = S.copy() if self.signal_norm: # mean-var scaling - if hasattr(self, 'mel_scaler'): + if hasattr(self, "mel_scaler"): if S_denorm.shape[0] == self.num_mels: return self.mel_scaler.inverse_transform(S_denorm.T).T elif S_denorm.shape[0] == self.fft_size / 2: return self.linear_scaler.inverse_transform(S_denorm.T).T else: - raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.') + raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.") if self.symmetric_norm: if self.clip_norm: - S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm) #pylint: disable=invalid-unary-operand-type + S_denorm = np.clip( + S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type + ) S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db return S_denorm + self.ref_level_db else: if self.clip_norm: S_denorm = np.clip(S_denorm, 0, self.max_norm) - S_denorm = (S_denorm * -self.min_level_db / - self.max_norm) + self.min_level_db + S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db return S_denorm + self.ref_level_db else: return S_denorm ### Mean-STD scaling ### def load_stats(self, stats_path): - stats = np.load(stats_path, allow_pickle=True).item() #pylint: disable=unexpected-keyword-arg - mel_mean = stats['mel_mean'] - mel_std = stats['mel_std'] - linear_mean = stats['linear_mean'] - linear_std = stats['linear_std'] - stats_config = stats['audio_config'] + stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg + mel_mean = stats["mel_mean"] + mel_std = stats["mel_std"] + linear_mean = stats["linear_mean"] + linear_std = stats["linear_std"] + stats_config = stats["audio_config"] # check all audio parameters used for computing stats - skip_parameters = ['griffin_lim_iters', 'stats_path', 'do_trim_silence', 'ref_level_db', 'power'] + skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"] for key in stats_config.keys(): if key in skip_parameters: continue - if key not in ['sample_rate', 'trim_db']: - assert stats_config[key] == self.__dict__[key],\ - f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" + if key not in ["sample_rate", "trim_db"]: + assert ( + stats_config[key] == self.__dict__[key] + ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}" return mel_mean, mel_std, linear_mean, linear_std, stats_config # pylint: disable=attribute-defined-outside-init @@ -243,7 +252,6 @@ class AudioProcessor(object): def _db_to_amp(self, x): return self.exp_func(x / self.spec_gain) - ### Preemphasis ### def apply_preemphasis(self, x): if self.preemphasis == 0: @@ -284,17 +292,17 @@ class AudioProcessor(object): S = self._db_to_amp(S) # Reconstruct phase if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) def inv_melspectrogram(self, mel_spectrogram): - '''Converts melspectrogram to waveform using librosa''' + """Converts melspectrogram to waveform using librosa""" D = self.denormalize(mel_spectrogram) S = self._db_to_amp(D) S = self._mel_to_linear(S) # Convert back to linear if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power)) + return self._griffin_lim(S ** self.power) def out_linear_to_mel(self, linear_spec): S = self.denormalize(linear_spec) @@ -315,8 +323,7 @@ class AudioProcessor(object): ) def _istft(self, y): - return librosa.istft( - y, hop_length=self.hop_length, win_length=self.win_length) + return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) def _griffin_lim(self, S): angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) @@ -328,8 +335,7 @@ class AudioProcessor(object): return y def compute_stft_paddings(self, x, pad_sides=1): - '''compute right padding (final frame) or both sides padding (first and final frames) - ''' + """compute right padding (final frame) or both sides padding (first and final frames)""" assert pad_sides in (1, 2) pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] if pad_sides == 1: @@ -354,7 +360,7 @@ class AudioProcessor(object): hop_length = int(window_length / 4) threshold = self._db_to_amp(threshold_db) for x in range(hop_length, len(wav) - window_length, hop_length): - if np.max(wav[x:x + window_length]) < threshold: + if np.max(wav[x : x + window_length]) < threshold: return x + hop_length return len(wav) @@ -362,8 +368,9 @@ class AudioProcessor(object): """ Trim silent parts with a threshold and 0.01 sec margin """ margin = int(self.sample_rate * 0.01) wav = wav[margin:-margin] - return librosa.effects.trim( - wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0] + return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ + 0 + ] @staticmethod def sound_norm(x): @@ -375,14 +382,14 @@ class AudioProcessor(object): x, sr = librosa.load(filename, sr=self.sample_rate) elif sr is None: x, sr = sf.read(filename) - assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr) + assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) else: x, sr = librosa.load(filename, sr=sr) if self.do_trim_silence: try: x = self.trim_silence(x) except ValueError: - print(f' [!] File cannot be trimmed for silence - {filename}') + print(f" [!] File cannot be trimmed for silence - {filename}") if self.do_sound_norm: x = self.sound_norm(x) return x @@ -396,10 +403,12 @@ class AudioProcessor(object): def mulaw_encode(wav, qc): mu = 2 ** qc - 1 # wav_abs = np.minimum(np.abs(wav), 1.0) - signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu) + signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) # Quantize signal to the specified number of levels. signal = (signal + 1) / 2 * mu + 0.5 - return np.floor(signal,) + return np.floor( + signal, + ) @staticmethod def mulaw_decode(wav, qc): @@ -408,15 +417,14 @@ class AudioProcessor(object): x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) return x - @staticmethod def encode_16bits(x): - return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16) + return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16) @staticmethod def quantize(x, bits): - return (x + 1.) * (2**bits - 1) / 2 + return (x + 1.0) * (2 ** bits - 1) / 2 @staticmethod def dequantize(x, bits): - return 2 * x / (2**bits - 1) - 1 + return 2 * x / (2 ** bits - 1) - 1 diff --git a/TTS/utils/console_logger.py b/TTS/utils/console_logger.py index 3affd6af..a035fa4e 100644 --- a/TTS/utils/console_logger.py +++ b/TTS/utils/console_logger.py @@ -2,19 +2,21 @@ import datetime from TTS.utils.io import AttrDict -tcolors = AttrDict({ - 'OKBLUE': '\033[94m', - 'HEADER': '\033[95m', - 'OKGREEN': '\033[92m', - 'WARNING': '\033[93m', - 'FAIL': '\033[91m', - 'ENDC': '\033[0m', - 'BOLD': '\033[1m', - 'UNDERLINE': '\033[4m' -}) +tcolors = AttrDict( + { + "OKBLUE": "\033[94m", + "HEADER": "\033[95m", + "OKGREEN": "\033[92m", + "WARNING": "\033[93m", + "FAIL": "\033[91m", + "ENDC": "\033[0m", + "BOLD": "\033[1m", + "UNDERLINE": "\033[4m", + } +) -class ConsoleLogger(): +class ConsoleLogger: def __init__(self): # TODO: color code for value changes # use these to compare values between iterations @@ -28,23 +30,24 @@ class ConsoleLogger(): return now.strftime("%Y-%m-%d %H:%M:%S") def print_epoch_start(self, epoch, max_epoch): - print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, - epoch, max_epoch, tcolors.ENDC), - flush=True) + print( + "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC), + flush=True, + ) def print_train_start(self): print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}") - def print_train_step(self, batch_steps, step, global_step, log_dict, - loss_dict, avg_loss_dict): + def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict): indent = " | > " print() log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format( - tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC) + tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC + ) for key, value in loss_dict.items(): # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) + if f"avg_{key}" in avg_loss_dict.keys(): + log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) else: log_text += "{}{}: {:.5f} \n".format(indent, key, value) for idx, (key, value) in enumerate(log_dict.items()): @@ -52,13 +55,12 @@ class ConsoleLogger(): log_text += f"{indent}{key}: {value[0]:.{value[1]}f}" else: log_text += f"{indent}{key}: {value}" - if idx < len(log_dict)-1: + if idx < len(log_dict) - 1: log_text += "\n" print(log_text, flush=True) # pylint: disable=unused-argument - def print_train_epoch_end(self, global_step, epoch, epoch_time, - print_dict): + def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict): indent = " | > " log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n" for key, value in print_dict.items(): @@ -74,29 +76,28 @@ class ConsoleLogger(): log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n" for key, value in loss_dict.items(): # print the avg value if given - if f'avg_{key}' in avg_loss_dict.keys(): - log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}']) + if f"avg_{key}" in avg_loss_dict.keys(): + log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"]) else: log_text += "{}{}: {:.5f} \n".format(indent, key, value) print(log_text, flush=True) def print_epoch_end(self, epoch, avg_loss_dict): indent = " | > " - log_text = " {}--> EVAL PERFORMANCE{}\n".format( - tcolors.BOLD, tcolors.ENDC) + log_text = " {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC) for key, value in avg_loss_dict.items(): # print the avg value if given - color = '' - sign = '+' + color = "" + sign = "+" diff = 0 if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict: diff = value - self.old_eval_loss_dict[key] if diff < 0: color = tcolors.OKGREEN - sign = '' + sign = "" elif diff > 0: color = tcolors.FAIL - sign = '+' + sign = "+" log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff) self.old_eval_loss_dict = avg_loss_dict print(log_text, flush=True) diff --git a/TTS/utils/distribute.py b/TTS/utils/distribute.py index 89d4efec..7a1078e8 100644 --- a/TTS/utils/distribute.py +++ b/TTS/utils/distribute.py @@ -14,7 +14,7 @@ class DistributedSampler(Sampler): """ def __init__(self, dataset, num_replicas=None, rank=None): - super(DistributedSampler, self).__init__(dataset) + super().__init__(dataset) if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -34,11 +34,11 @@ class DistributedSampler(Sampler): indices = torch.arange(len(self.dataset)).tolist() # add extra samples to make it evenly divisible - indices += indices[:(self.total_size - len(indices))] + indices += indices[: (self.total_size - len(indices))] assert len(indices) == self.total_size # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] + indices = indices[self.rank : self.total_size : self.num_replicas] assert len(indices) == self.num_samples return iter(indices) @@ -64,12 +64,7 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url): torch.cuda.set_device(rank % torch.cuda.device_count()) # Initialize distributed communication - dist.init_process_group( - dist_backend, - init_method=dist_url, - world_size=num_gpus, - rank=rank, - group_name=group_name) + dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name) def apply_gradient_allreduce(module): @@ -97,14 +92,13 @@ def apply_gradient_allreduce(module): coalesced = _flatten_dense_tensors(grads) dist.all_reduce(coalesced, op=dist.reduce_op.SUM) coalesced /= dist.get_world_size() - for buf, synced in zip( - grads, _unflatten_dense_tensors(coalesced, grads)): + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): buf.copy_(synced) for param in list(module.parameters()): def allreduce_hook(*_): - Variable._execution_engine.queue_callback(allreduce_params) #pylint: disable=protected-access + Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access if param.requires_grad: param.register_hook(allreduce_hook) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 60721364..57e22707 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -10,8 +10,7 @@ from pathlib import Path def get_git_branch(): try: out = subprocess.check_output(["git", "branch"]).decode("utf8") - current = next(line for line in out.split("\n") - if line.startswith("*")) + current = next(line for line in out.split("\n") if line.startswith("*")) current.replace("* ", "") except subprocess.CalledProcessError: current = "inside_docker" @@ -29,12 +28,11 @@ def get_commit_hash(): # raise RuntimeError( # " !! Commit before training to get the commit hash.") try: - commit = subprocess.check_output( - ['git', 'rev-parse', '--short', 'HEAD']).decode().strip() + commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip() # Not copying .git folder into docker container except (subprocess.CalledProcessError, FileNotFoundError): commit = "0000000" - print(' > Git Hash: {}'.format(commit)) + print(" > Git Hash: {}".format(commit)) return commit @@ -42,11 +40,10 @@ def create_experiment_folder(root_path, model_name, debug): """ Create a folder with the current date and time """ date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p") if debug: - commit_hash = 'debug' + commit_hash = "debug" else: commit_hash = get_commit_hash() - output_folder = os.path.join( - root_path, model_name + '-' + date_str + '-' + commit_hash) + output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash) os.makedirs(output_folder, exist_ok=True) print(" > Experiment folder: {}".format(output_folder)) return output_folder @@ -72,16 +69,16 @@ def count_parameters(model): def get_user_data_dir(appname): if sys.platform == "win32": import winreg # pylint: disable=import-outside-toplevel + key = winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" + winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" ) dir_, _ = winreg.QueryValueEx(key, "Local AppData") ans = Path(dir_).resolve(strict=False) - elif sys.platform == 'darwin': - ans = Path('~/Library/Application Support/').expanduser() + elif sys.platform == "darwin": + ans = Path("~/Library/Application Support/").expanduser() else: - ans = Path.home().joinpath('.local/share') + ans = Path.home().joinpath(".local/share") return ans.joinpath(appname) @@ -91,32 +88,20 @@ def set_init_dict(model_dict, checkpoint_state, c): if k not in model_dict: print(" | > Layer missing in the model definition: {}".format(k)) # 1. filter out unnecessary keys - pretrained_dict = { - k: v - for k, v in checkpoint_state.items() if k in model_dict - } + pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict} # 2. filter out different size layers - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if v.numel() == model_dict[k].numel() - } + pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()} # 3. skip reinit layers if c.reinit_layers is not None: for reinit_layer_name in c.reinit_layers: - pretrained_dict = { - k: v - for k, v in pretrained_dict.items() - if reinit_layer_name not in k - } + pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k} # 4. overwrite entries in the existing state dict model_dict.update(pretrained_dict) - print(" | > {} / {} layers are restored.".format(len(pretrained_dict), - len(model_dict))) + print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict))) return model_dict -class KeepAverage(): +class KeepAverage: def __init__(self): self.avg_values = {} self.iters = {} @@ -141,8 +126,7 @@ class KeepAverage(): self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value self.iters[name] += 1 else: - self.avg_values[name] = self.avg_values[name] * \ - self.iters[name] + value + self.avg_values[name] = self.avg_values[name] * self.iters[name] + value self.iters[name] += 1 self.avg_values[name] /= self.iters[name] @@ -155,23 +139,27 @@ class KeepAverage(): self.update_value(key, value) -def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None): +def check_argument( + name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None +): if alternative in c.keys() and c[alternative] is not None: return if restricted: - assert name in c.keys(), f' [!] {name} not defined in config.json' + assert name in c.keys(), f" [!] {name} not defined in config.json" if name in c.keys(): if max_val: - assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}' + assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}" if min_val: - assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}' + assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}" if enum_list: - assert c[name].lower() in enum_list, f' [!] {name} is not a valid value' + assert c[name].lower() in enum_list, f" [!] {name} is not a valid value" if isinstance(val_type, list): is_valid = False for typ in val_type: if isinstance(c[name], typ): is_valid = True - assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}" elif val_type: - assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}' + assert ( + isinstance(c[name], val_type) or c[name] is None + ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}" diff --git a/TTS/utils/io.py b/TTS/utils/io.py index 1703de6f..846c6fc1 100644 --- a/TTS/utils/io.py +++ b/TTS/utils/io.py @@ -8,15 +8,17 @@ from shutil import copyfile class RenamingUnpickler(pickle_tts.Unpickler): """Overload default pickler to solve module renaming problem""" + def find_class(self, module, name): - return super().find_class(module.replace('mozilla_voice_tts', 'TTS'), name) + return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name) class AttrDict(dict): """A custom dict which converts dict keys to class attributes""" + def __init__(self, *args, **kwargs): - super(AttrDict, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.__dict__ = self @@ -25,11 +27,12 @@ def read_json_with_comments(json_path): with open(json_path, "r", encoding="utf-8") as f: input_str = f.read() # handle comments - input_str = re.sub(r'\\\n', '', input_str) - input_str = re.sub(r'//.*\n', '\n', input_str) + input_str = re.sub(r"\\\n", "", input_str) + input_str = re.sub(r"//.*\n", "\n", input_str) data = json.loads(input_str) return data + def load_config(config_path: str) -> AttrDict: """Load config files and discard comments @@ -60,7 +63,7 @@ def copy_model_files(c, config_file, out_path, new_fields): in the config file. """ # copy config.json - copy_config_path = os.path.join(out_path, 'config.json') + copy_config_path = os.path.join(out_path, "config.json") config_lines = open(config_file, "r", encoding="utf-8").readlines() # add extra information fields for key, value in new_fields.items(): @@ -73,7 +76,10 @@ def copy_model_files(c, config_file, out_path, new_fields): config_out_file.writelines(config_lines) config_out_file.close() # copy model stats file if available - if c.audio['stats_path'] is not None: - copy_stats_path = os.path.join(out_path, 'scale_stats.npy') + if c.audio["stats_path"] is not None: + copy_stats_path = os.path.join(out_path, "scale_stats.npy") if not os.path.exists(copy_stats_path): - copyfile(c.audio['stats_path'], copy_stats_path, ) + copyfile( + c.audio["stats_path"], + copy_stats_path, + ) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index ef77ca4e..ad2dd8b9 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -22,12 +22,13 @@ class ModelManager(object): Args: models_file (str): path to .model.json """ + def __init__(self, models_file=None, output_prefix=None): super().__init__() if output_prefix is None: - self.output_prefix = get_user_data_dir('tts') + self.output_prefix = get_user_data_dir("tts") else: - self.output_prefix = os.path.join(output_prefix, 'tts') + self.output_prefix = os.path.join(output_prefix, "tts") self.url_prefix = "https://drive.google.com/uc?id=" self.models_dict = None if models_file is not None: @@ -72,7 +73,7 @@ class ModelManager(object): print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]") else: print(f" >: {model_type}/{lang}/{dataset}/{model}") - models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}') + models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}") return models_name_list def download_model(self, model_name): @@ -104,25 +105,25 @@ class ModelManager(object): else: os.makedirs(output_path, exist_ok=True) print(f" > Downloading model to {output_path}") - output_stats_path = os.path.join(output_path, 'scale_stats.npy') + output_stats_path = os.path.join(output_path, "scale_stats.npy") # download files to the output path - if self._check_dict_key(model_item, 'github_rls_url'): + if self._check_dict_key(model_item, "github_rls_url"): # download from github release # TODO: pass output_path - self._download_zip_file(model_item['github_rls_url'], output_path) + self._download_zip_file(model_item["github_rls_url"], output_path) else: # download from gdrive - self._download_gdrive_file(model_item['model_file'], output_model_path) - self._download_gdrive_file(model_item['config_file'], output_config_path) - if self._check_dict_key(model_item, 'stats_file'): - self._download_gdrive_file(model_item['stats_file'], output_stats_path) + self._download_gdrive_file(model_item["model_file"], output_model_path) + self._download_gdrive_file(model_item["config_file"], output_config_path) + if self._check_dict_key(model_item, "stats_file"): + self._download_gdrive_file(model_item["stats_file"], output_stats_path) # set the scale_path.npy file path in the model config.json - if self._check_dict_key(model_item, 'stats_file') or os.path.exists(output_stats_path): + if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path): # set scale stats path in config.json config_path = output_config_path config = load_config(config_path) - config["audio"]['stats_path'] = output_stats_path + config["audio"]["stats_path"] = output_stats_path with open(config_path, "w") as jf: json.dump(config, jf) return output_model_path, output_config_path, model_item diff --git a/TTS/utils/radam.py b/TTS/utils/radam.py index 58cec920..37403929 100644 --- a/TTS/utils/radam.py +++ b/TTS/utils/radam.py @@ -6,7 +6,6 @@ from torch.optim.optimizer import Optimizer class RAdam(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): if lr < 0.0: raise ValueError("Invalid learning rate: {}".format(lr)) @@ -20,13 +19,15 @@ class RAdam(Optimizer): self.degenerated_to_sgd = degenerated_to_sgd if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): for param in params: - if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): - param['buffer'] = [[None, None, None] for _ in range(10)] - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) - super(RAdam, self).__init__(params, defaults) + if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]): + param["buffer"] = [[None, None, None] for _ in range(10)] + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)] + ) + super().__init__(params, defaults) def __setstate__(self, state): # pylint: disable=useless-super-delegation - super(RAdam, self).__setstate__(state) + super().__setstate__(state) def step(self, closure=None): @@ -36,62 +37,70 @@ class RAdam(Optimizer): for group in self.param_groups: - for p in group['params']: + for p in group["params"]: if p.grad is None: continue grad = p.grad.data.float() if grad.is_sparse: - raise RuntimeError('RAdam does not support sparse gradients') + raise RuntimeError("RAdam does not support sparse gradients") p_data_fp32 = p.data.float() state = self.state[p] if len(state) == 0: - state['step'] = 0 - state['exp_avg'] = torch.zeros_like(p_data_fp32) - state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p_data_fp32) + state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) else: - state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) - state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32) + state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32) - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + beta1, beta2 = group["betas"] exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - state['step'] += 1 - buffered = group['buffer'][int(state['step'] % 10)] - if state['step'] == buffered[0]: + state["step"] += 1 + buffered = group["buffer"][int(state["step"] % 10)] + if state["step"] == buffered[0]: N_sma, step_size = buffered[1], buffered[2] else: - buffered[0] = state['step'] - beta2_t = beta2 ** state['step'] + buffered[0] = state["step"] + beta2_t = beta2 ** state["step"] N_sma_max = 2 / (1 - beta2) - 1 - N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t) buffered[1] = N_sma # more conservative since it's an approximated value if N_sma >= 5: - step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + step_size = math.sqrt( + (1 - beta2_t) + * (N_sma - 4) + / (N_sma_max - 4) + * (N_sma - 2) + / N_sma + * N_sma_max + / (N_sma_max - 2) + ) / (1 - beta1 ** state["step"]) elif self.degenerated_to_sgd: - step_size = 1.0 / (1 - beta1 ** state['step']) + step_size = 1.0 / (1 - beta1 ** state["step"]) else: step_size = -1 buffered[2] = step_size # more conservative since it's an approximated value if N_sma >= 5: - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) - denom = exp_avg_sq.sqrt().add_(group['eps']) - p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"]) p.data.copy_(p_data_fp32) elif step_size > 0: - if group['weight_decay'] != 0: - p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr']) - p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + if group["weight_decay"] != 0: + p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"]) + p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"]) p.data.copy_(p_data_fp32) return loss diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index a7a82d13..b8896ec4 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -9,6 +9,7 @@ from TTS.utils.io import load_config from TTS.tts.utils.generic_utils import setup_model from TTS.tts.utils.speakers import load_speaker_mapping from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input + # pylint: disable=unused-wildcard-import # pylint: disable=wildcard-import from TTS.tts.utils.synthesis import synthesis, trim_silence @@ -49,12 +50,11 @@ class Synthesizer(object): self.use_cuda = use_cuda if self.use_cuda: assert torch.cuda.is_available(), "CUDA is not availabe on this machine." - self.load_tts(tts_checkpoint, tts_config, - use_cuda) - self.output_sample_rate = self.tts_config.audio['sample_rate'] + self.load_tts(tts_checkpoint, tts_config, use_cuda) + self.output_sample_rate = self.tts_config.audio["sample_rate"] if vocoder_checkpoint: self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) - self.output_sample_rate = self.vocoder_config.audio['sample_rate'] + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] @staticmethod def get_segmenter(lang): @@ -69,16 +69,18 @@ class Synthesizer(object): self.num_speakers = 0 # set external speaker embedding if self.tts_config.use_external_speaker_embedding_file: - speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]['embedding'] + speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"] self.speaker_embedding_dim = len(speaker_embedding) def init_speaker(self, speaker_idx): # load speakers speaker_embedding = None - if hasattr(self, 'tts_speakers') and speaker_idx is not None: - assert speaker_idx < len(self.tts_speakers), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}" + if hasattr(self, "tts_speakers") and speaker_idx is not None: + assert speaker_idx < len( + self.tts_speakers + ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}" if self.tts_config.use_external_speaker_embedding_file: - speaker_embedding = self.tts_speakers[speaker_idx]['embedding'] + speaker_embedding = self.tts_speakers[speaker_idx]["embedding"] return speaker_embedding def load_tts(self, tts_checkpoint, tts_config, use_cuda): @@ -90,7 +92,7 @@ class Synthesizer(object): self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(verbose=False, **self.tts_config.audio) - if 'characters' in self.tts_config.keys(): + if "characters" in self.tts_config.keys(): symbols, phonemes = make_symbols(**self.tts_config.characters) if self.use_phonemes: @@ -105,7 +107,7 @@ class Synthesizer(object): def load_vocoder(self, model_file, model_config, use_cuda): self.vocoder_config = load_config(model_config) - self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config['audio']) + self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"]) self.vocoder_model = setup_generator(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) if use_cuda: @@ -141,7 +143,8 @@ class Synthesizer(object): False, self.tts_config.enable_eos_bos_chars, use_gl, - speaker_embedding=speaker_embedding) + speaker_embedding=speaker_embedding, + ) if not use_gl: # denormalize tts output based on tts audio config mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T @@ -149,7 +152,7 @@ class Synthesizer(object): # renormalize spectrogram based on vocoder config vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T) # compute scale factor for possible sample rate mismatch - scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate] + scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate] if scale_factor[1] != 1: print(" > interpolating tts model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) @@ -172,7 +175,7 @@ class Synthesizer(object): # compute stats process_time = time.time() - start_time - audio_time = len(wavs) / self.tts_config.audio['sample_rate'] + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] print(f" > Processing time: {process_time}") print(f" > Real-time factor: {process_time / audio_time}") return wavs diff --git a/TTS/utils/tensorboard_logger.py b/TTS/utils/tensorboard_logger.py index 4ee12d74..769d47f5 100644 --- a/TTS/utils/tensorboard_logger.py +++ b/TTS/utils/tensorboard_logger.py @@ -13,40 +13,28 @@ class TensorboardLogger(object): layer_num = 1 for name, param in model.named_parameters(): if param.numel() == 1: - self.writer.add_scalar( - "layer{}-{}/value".format(layer_num, name), - param.max(), step) + self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step) else: - self.writer.add_scalar( - "layer{}-{}/max".format(layer_num, name), - param.max(), step) - self.writer.add_scalar( - "layer{}-{}/min".format(layer_num, name), - param.min(), step) - self.writer.add_scalar( - "layer{}-{}/mean".format(layer_num, name), - param.mean(), step) - self.writer.add_scalar( - "layer{}-{}/std".format(layer_num, name), - param.std(), step) - self.writer.add_histogram( - "layer{}-{}/param".format(layer_num, name), param, step) - self.writer.add_histogram( - "layer{}-{}/grad".format(layer_num, name), param.grad, step) + self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step) + self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step) + self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step) + self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step) + self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step) + self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step) layer_num += 1 def dict_to_tb_scalar(self, scope_name, stats, step): for key, value in stats.items(): - self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step) + self.writer.add_scalar("{}/{}".format(scope_name, key), value, step) def dict_to_tb_figure(self, scope_name, figures, step): for key, value in figures.items(): - self.writer.add_figure('{}/{}'.format(scope_name, key), value, step) + self.writer.add_figure("{}/{}".format(scope_name, key), value, step) def dict_to_tb_audios(self, scope_name, audios, step, sample_rate): for key, value in audios.items(): try: - self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate) + self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate) except RuntimeError: traceback.print_exc() diff --git a/TTS/utils/training.py b/TTS/utils/training.py index 8166562c..56255100 100644 --- a/TTS/utils/training.py +++ b/TTS/utils/training.py @@ -14,12 +14,13 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark): def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): - r'''Check model gradient against unexpected jumps and failures''' + r"""Check model gradient against unexpected jumps and failures""" skip_flag = False if ignore_stopnet: if not amp_opt_params: grad_norm = torch.nn.utils.clip_grad_norm_( - [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip) + [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip + ) else: grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip) else: @@ -41,11 +42,10 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None): def lr_decay(init_lr, global_step, warmup_steps): - r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py''' + r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py""" warmup_steps = float(warmup_steps) - step = global_step + 1. - lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5, - step**-0.5) + step = global_step + 1.0 + lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5) return lr @@ -54,14 +54,14 @@ def adam_weight_decay(optimizer): Custom weight decay operation, not effecting grad values. """ for group in optimizer.param_groups: - for param in group['params']: - current_lr = group['lr'] - weight_decay = group['weight_decay'] - factor = -weight_decay * group['lr'] - param.data = param.data.add(param.data, - alpha=factor) + for param in group["params"]: + current_lr = group["lr"] + weight_decay = group["weight_decay"] + factor = -weight_decay * group["lr"] + param.data = param.data.add(param.data, alpha=factor) return optimizer, current_lr + # pylint: disable=dangerous-default-value def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}): """ @@ -74,30 +74,23 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn if not param.requires_grad: continue - if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]): + if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)): no_decay.append(param) else: decay.append(param) - return [{ - 'params': no_decay, - 'weight_decay': 0. - }, { - 'params': decay, - 'weight_decay': weight_decay - }] + return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}] # pylint: disable=protected-access class NoamLR(torch.optim.lr_scheduler._LRScheduler): def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1): self.warmup_steps = float(warmup_steps) - super(NoamLR, self).__init__(optimizer, last_epoch) + super().__init__(optimizer, last_epoch) def get_lr(self): step = max(self.last_epoch, 1) return [ - base_lr * self.warmup_steps**0.5 * - min(step * self.warmup_steps**-1.5, step**-0.5) + base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5) for base_lr in self.base_lrs ] diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 1ab2c974..4010b628 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -13,19 +13,22 @@ class GANDataset(Dataset): and converts them to acoustic features on the fly and returns random segments of (audio, feature) couples. """ - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad_short, - conv_pad=2, - is_training=True, - return_segments=True, - use_noise_augment=False, - use_cache=False, - verbose=False): - super(GANDataset, self).__init__() + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): + super().__init__() self.ap = ap self.item_list = items self.compute_feat = not isinstance(items[0], (tuple, list)) @@ -57,14 +60,14 @@ class GANDataset(Dataset): @staticmethod def find_wav_files(path): - return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True) + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) def __len__(self): return len(self.item_list) def __getitem__(self, idx): - """ Return different items for Generator and Discriminator and - cache acoustic features """ + """Return different items for Generator and Discriminator and + cache acoustic features""" if self.return_segments: idx2 = self.G_to_D_mappings[idx] item1 = self.load_item(idx) @@ -76,13 +79,16 @@ class GANDataset(Dataset): def _pad_short_samples(self, audio, mel=None): """Pad samples shorter than the output sequence length""" if len(audio) < self.seq_len: - audio = np.pad(audio, (0, self.seq_len - len(audio)), - mode='constant', - constant_values=0.0) + audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0) if mel is not None and mel.shape[1] < self.feat_frame_len: pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0] - mel = np.pad(mel, ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), mode='constant', constant_values=pad_value.mean()) + mel = np.pad( + mel, + ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), + mode="constant", + constant_values=pad_value.mean(), + ) return audio, mel def shuffle_mapping(self): @@ -111,12 +117,14 @@ class GANDataset(Dataset): else: audio = self.ap.load_wav(wavpath) mel = np.load(feat_path) - audio, mel= self._pad_short_samples(audio, mel) + audio, mel = self._pad_short_samples(audio, mel) # correct the audio length wrt padding applied in stft audio = np.pad(audio, (0, self.hop_len), mode="edge") - audio = audio[:mel.shape[-1] * self.hop_len] - assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}' + audio = audio[: mel.shape[-1] * self.hop_len] + assert ( + mel.shape[-1] * self.hop_len == audio.shape[-1] + ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}" audio = torch.from_numpy(audio).float().unsqueeze(0) mel = torch.from_numpy(mel).float().squeeze(0) @@ -128,8 +136,7 @@ class GANDataset(Dataset): mel = mel[:, mel_start:mel_end] audio_start = mel_start * self.hop_len - audio = audio[:, audio_start:audio_start + - self.seq_len] + audio = audio[:, audio_start : audio_start + self.seq_len] if self.use_noise_augment and self.is_training and self.return_segments: audio = audio + (1 / 32768) * torch.randn_like(audio) diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index afea45fd..2aa1d116 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -18,11 +18,7 @@ def preprocess_wav_files(out_path, config, ap): mel = ap.melspectrogram(y) np.save(mel_path, mel) if isinstance(config.mode, int): - quant = ( - ap.mulaw_encode(y, qc=config.mode) - if config.mulaw - else ap.quantize(y, bits=config.mode) - ) + quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode) np.save(quant_path, quant) diff --git a/TTS/vocoder/datasets/wavegrad_dataset.py b/TTS/vocoder/datasets/wavegrad_dataset.py index 6cd5862a..be2b5a8a 100644 --- a/TTS/vocoder/datasets/wavegrad_dataset.py +++ b/TTS/vocoder/datasets/wavegrad_dataset.py @@ -13,18 +13,21 @@ class WaveGradDataset(Dataset): and converts them to acoustic features on the fly and returns random segments of (audio, feature) couples. """ - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad_short, - conv_pad=2, - is_training=True, - return_segments=True, - use_noise_augment=False, - use_cache=False, - verbose=False): + + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad_short, + conv_pad=2, + is_training=True, + return_segments=True, + use_noise_augment=False, + use_cache=False, + verbose=False, + ): super().__init__() self.ap = ap @@ -54,7 +57,7 @@ class WaveGradDataset(Dataset): @staticmethod def find_wav_files(path): - return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True) + return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True) def __len__(self): return len(self.item_list) @@ -86,13 +89,16 @@ class WaveGradDataset(Dataset): if self.return_segments: # correct audio length wrt segment length if audio.shape[-1] < self.seq_len + self.pad_short: - audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \ - mode='constant', constant_values=0.0) - assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" + audio = np.pad( + audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0 + ) + assert ( + audio.shape[-1] >= self.seq_len + self.pad_short + ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}" # correct the audio length wrt hop length p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1] - audio = np.pad(audio, (0, p), mode='constant', constant_values=0.0) + audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0) if self.use_cache: self.cache[idx] = audio @@ -126,7 +132,7 @@ class WaveGradDataset(Dataset): for idx, b in enumerate(batch): mel = b[0] audio = b[1] - mels[idx, :, :mel.shape[1]] = mel - audios[idx, :audio.shape[0]] = audio + mels[idx, :, : mel.shape[1]] = mel + audios[idx, : audio.shape[0]] = audio return mels, audios diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index d45932c9..4a554580 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -9,19 +9,20 @@ class WaveRNNDataset(Dataset): and converts them to acoustic features on the fly. """ - def __init__(self, - ap, - items, - seq_len, - hop_len, - pad, - mode, - mulaw, - is_training=True, - verbose=False, - ): + def __init__( + self, + ap, + items, + seq_len, + hop_len, + pad, + mode, + mulaw, + is_training=True, + verbose=False, + ): - super(WaveRNNDataset, self).__init__() + super().__init__() self.ap = ap self.compute_feat = not isinstance(items[0], (tuple, list)) self.item_list = items @@ -61,8 +62,9 @@ class WaveRNNDataset(Dataset): if self.mode in ["gauss", "mold"]: x_input = audio elif isinstance(self.mode, int): - x_input = (self.ap.mulaw_encode(audio, qc=self.mode) - if self.mulaw else self.ap.quantize(audio, bits=self.mode)) + x_input = ( + self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) + ) else: raise RuntimeError("Unknown dataset mode - ", self.mode) @@ -71,7 +73,7 @@ class WaveRNNDataset(Dataset): wavpath, feat_path = self.item_list[index] mel = np.load(feat_path.replace("/quant/", "/mel/")) - if mel.shape[-1] < self.mel_len + 2 * self.pad: + if mel.shape[-1] < self.mel_len + 2 * self.pad: print(" [!] Instance is too short! : {}".format(wavpath)) self.item_list[index] = self.item_list[index + 1] feat_path = self.item_list[index] @@ -87,22 +89,14 @@ class WaveRNNDataset(Dataset): def collate(self, batch): mel_win = self.seq_len // self.hop_len + 2 * self.pad - max_offsets = [x[0].shape[-1] - - (mel_win + 2 * self.pad) for x in batch] + max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch] mel_offsets = [np.random.randint(0, offset) for offset in max_offsets] - sig_offsets = [(offset + self.pad) * - self.hop_len for offset in mel_offsets] + sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets] - mels = [ - x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win] - for i, x in enumerate(batch) - ] + mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)] - coarse = [ - x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1] - for i, x in enumerate(batch) - ] + coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)] mels = np.stack(mels).astype(np.float32) if self.mode in ["gauss", "mold"]: @@ -112,8 +106,7 @@ class WaveRNNDataset(Dataset): elif isinstance(self.mode, int): coarse = np.stack(coarse).astype(np.int64) coarse = torch.LongTensor(coarse) - x_input = (2 * coarse[:, : self.seq_len].float() / - (2 ** self.mode - 1.0) - 1.0) + x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0 y_coarse = coarse[:, 1:] mels = torch.FloatTensor(mels) return x_input, mels, y_coarse diff --git a/TTS/vocoder/layers/losses.py b/TTS/vocoder/layers/losses.py index fab70594..f988565b 100644 --- a/TTS/vocoder/layers/losses.py +++ b/TTS/vocoder/layers/losses.py @@ -6,18 +6,21 @@ from torch.nn import functional as F class TorchSTFT(nn.Module): # pylint: disable=abstract-method """TODO: Merge this with audio.py""" - def __init__(self, - n_fft, - hop_length, - win_length, - pad_wav=False, - window='hann_window', - sample_rate=None, - mel_fmin=0, - mel_fmax=None, - n_mels=80, - use_mel=False): - super(TorchSTFT, self).__init__() + + def __init__( + self, + n_fft, + hop_length, + win_length, + pad_wav=False, + window="hann_window", + sample_rate=None, + mel_fmin=0, + mel_fmax=None, + n_mels=80, + use_mel=False, + ): + super().__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length @@ -27,8 +30,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method self.mel_fmax = mel_fmax self.n_mels = n_mels self.use_mel = use_mel - self.window = nn.Parameter(getattr(torch, window)(win_length), - requires_grad=False) + self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False) self.mel_basis = None if use_mel: self._build_mel_basis() @@ -49,7 +51,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method x = x.unsqueeze(1) if self.pad_wav: padding = int((self.n_fft - self.hop_length) / 2) - x = torch.nn.functional.pad(x, (padding, padding), mode='reflect') + x = torch.nn.functional.pad(x, (padding, padding), mode="reflect") # B x D x T x 2 o = torch.stft( x.squeeze(1), @@ -61,35 +63,34 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method pad_mode="reflect", # compatible with audio.py normalized=False, onesided=True, - return_complex=False) + return_complex=False, + ) M = o[:, :, :, 0] P = o[:, :, :, 1] - S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8)) + S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8)) if self.use_mel: S = torch.matmul(self.mel_basis.to(x), S) return S def _build_mel_basis(self): - mel_basis = librosa.filters.mel(self.sample_rate, - self.n_fft, - n_mels=self.n_mels, - fmin=self.mel_fmin, - fmax=self.mel_fmax) + mel_basis = librosa.filters.mel( + self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax + ) self.mel_basis = torch.from_numpy(mel_basis).float() - ################################# # GENERATOR LOSSES ################################# class STFTLoss(nn.Module): - """ STFT loss. Input generate and real waveforms are converted + """STFT loss. Input generate and real waveforms are converted to spectrograms compared with L1 and Spectral convergence losses. It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" + def __init__(self, n_fft, hop_length, win_length): - super(STFTLoss, self).__init__() + super().__init__() self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length @@ -104,15 +105,14 @@ class STFTLoss(nn.Module): loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro") return loss_mag, loss_sc + class MultiScaleSTFTLoss(torch.nn.Module): - """ Multi-scale STFT loss. Input generate and real waveforms are converted + """Multi-scale STFT loss. Input generate and real waveforms are converted to spectrograms compared with L1 and Spectral convergence losses. It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf""" - def __init__(self, - n_ffts=(1024, 2048, 512), - hop_lengths=(120, 240, 50), - win_lengths=(600, 1200, 240)): - super(MultiScaleSTFTLoss, self).__init__() + + def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)): + super().__init__() self.loss_funcs = torch.nn.ModuleList() for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths): self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length)) @@ -129,19 +129,25 @@ class MultiScaleSTFTLoss(torch.nn.Module): loss_mag /= N return loss_mag, loss_sc + class L1SpecLoss(nn.Module): """ L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf""" - def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True): + + def __init__( + self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True + ): super().__init__() self.use_mel = use_mel - self.stft = TorchSTFT(n_fft, - hop_length, - win_length, - sample_rate=sample_rate, - mel_fmin=mel_fmin, - mel_fmax=mel_fmax, - n_mels=n_mels, - use_mel=use_mel) + self.stft = TorchSTFT( + n_fft, + hop_length, + win_length, + sample_rate=sample_rate, + mel_fmin=mel_fmin, + mel_fmax=mel_fmax, + n_mels=n_mels, + use_mel=use_mel, + ) def forward(self, y_hat, y): y_hat_M = self.stft(y_hat) @@ -150,9 +156,11 @@ class L1SpecLoss(nn.Module): loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M)) return loss_mag + class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): - """ Multiscale STFT loss for multi band model outputs. + """Multiscale STFT loss for multi band model outputs. From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106""" + # pylint: disable=no-self-use def forward(self, y_hat, y): y_hat = y_hat.view(-1, 1, y_hat.shape[2]) @@ -162,6 +170,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss): class MSEGLoss(nn.Module): """ Mean Squared Generator Loss """ + # pylint: disable=no-self-use def forward(self, score_real): loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape)) @@ -170,10 +179,11 @@ class MSEGLoss(nn.Module): class HingeGLoss(nn.Module): """ Hinge Discriminator Loss """ + # pylint: disable=no-self-use def forward(self, score_real): # TODO: this might be wrong - loss_fake = torch.mean(F.relu(1. - score_real)) + loss_fake = torch.mean(F.relu(1.0 - score_real)) return loss_fake @@ -184,8 +194,11 @@ class HingeGLoss(nn.Module): class MSEDLoss(nn.Module): """ Mean Squared Discriminator Loss """ - def __init__(self,): - super(MSEDLoss, self).__init__() + + def __init__( + self, + ): + super().__init__() self.loss_func = nn.MSELoss() # pylint: disable=no-self-use @@ -198,17 +211,20 @@ class MSEDLoss(nn.Module): class HingeDLoss(nn.Module): """ Hinge Discriminator Loss """ + # pylint: disable=no-self-use def forward(self, score_fake, score_real): - loss_real = torch.mean(F.relu(1. - score_real)) - loss_fake = torch.mean(F.relu(1. + score_fake)) + loss_real = torch.mean(F.relu(1.0 - score_real)) + loss_fake = torch.mean(F.relu(1.0 + score_fake)) loss_d = loss_real + loss_fake return loss_d, loss_real, loss_fake class MelganFeatureLoss(nn.Module): - def __init__(self,): - super(MelganFeatureLoss, self).__init__() + def __init__( + self, + ): + super().__init__() self.loss_func = nn.L1Loss() # pylint: disable=no-self-use @@ -229,8 +245,8 @@ class MelganFeatureLoss(nn.Module): def _apply_G_adv_loss(scores_fake, loss_func): - """ Compute G adversarial loss function - and normalize values """ + """Compute G adversarial loss function + and normalize values""" adv_loss = 0 if isinstance(scores_fake, list): for score_fake in scores_fake: @@ -279,24 +295,26 @@ class GeneratorLoss(nn.Module): Args: C (AttrDict): model configuration. """ + def __init__(self, C): super().__init__() - assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ - " [!] Cannot use HingeGANLoss and MSEGANLoss together." + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." - self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False - self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False - self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False - self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False - self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False - self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False + self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False + self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False + self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False + self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False + self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False + self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False - self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0 - self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0 - self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0 - self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinde_G_loss_weight' in C else 0.0 - self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0 - self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0 + self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0 + self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0 + self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0 + self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0 + self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0 + self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0 if C.use_stft_loss: self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params) @@ -309,63 +327,67 @@ class GeneratorLoss(nn.Module): if C.use_feat_match_loss: self.feat_match_loss = MelganFeatureLoss() if C.use_l1_spec_loss: - assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate'] + assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"] self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params) - def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None): + def forward( + self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None + ): gen_loss = 0 adv_loss = 0 return_dict = {} # STFT Loss if self.use_stft_loss: - stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1)) - return_dict['G_stft_loss_mg'] = stft_loss_mg - return_dict['G_stft_loss_sc'] = stft_loss_sc + stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1)) + return_dict["G_stft_loss_mg"] = stft_loss_mg + return_dict["G_stft_loss_sc"] = stft_loss_sc gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc) # L1 Spec loss if self.use_l1_spec_loss: l1_spec_loss = self.l1_spec_loss(y_hat, y) - return_dict['G_l1_spec_loss'] = l1_spec_loss + return_dict["G_l1_spec_loss"] = l1_spec_loss gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss # subband STFT Loss if self.use_subband_stft_loss: subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub) - return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg - return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc + return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg + return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc) # multiscale MSE adversarial loss if self.use_mse_gan_loss and scores_fake is not None: mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss) - return_dict['G_mse_fake_loss'] = mse_fake_loss + return_dict["G_mse_fake_loss"] = mse_fake_loss adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss # multiscale Hinge adversarial loss if self.use_hinge_gan_loss and not scores_fake is not None: hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss) - return_dict['G_hinge_fake_loss'] = hinge_fake_loss + return_dict["G_hinge_fake_loss"] = hinge_fake_loss adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss # Feature Matching Loss if self.use_feat_match_loss and not feats_fake is None: feat_match_loss = self.feat_match_loss(feats_fake, feats_real) - return_dict['G_feat_match_loss'] = feat_match_loss + return_dict["G_feat_match_loss"] = feat_match_loss adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss - return_dict['G_loss'] = gen_loss + adv_loss - return_dict['G_gen_loss'] = gen_loss - return_dict['G_adv_loss'] = adv_loss + return_dict["G_loss"] = gen_loss + adv_loss + return_dict["G_gen_loss"] = gen_loss + return_dict["G_adv_loss"] = adv_loss return return_dict class DiscriminatorLoss(nn.Module): """Like ```GeneratorLoss```""" + def __init__(self, C): super().__init__() - assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\ - " [!] Cannot use HingeGANLoss and MSEGANLoss together." + assert not ( + C.use_mse_gan_loss and C.use_hinge_gan_loss + ), " [!] Cannot use HingeGANLoss and MSEGANLoss together." self.use_mse_gan_loss = C.use_mse_gan_loss self.use_hinge_gan_loss = C.use_hinge_gan_loss @@ -381,23 +403,21 @@ class DiscriminatorLoss(nn.Module): if self.use_mse_gan_loss: mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss( - scores_fake=scores_fake, - scores_real=scores_real, - loss_func=self.mse_loss) - return_dict['D_mse_gan_loss'] = mse_D_loss - return_dict['D_mse_gan_real_loss'] = mse_D_real_loss - return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss + ) + return_dict["D_mse_gan_loss"] = mse_D_loss + return_dict["D_mse_gan_real_loss"] = mse_D_real_loss + return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss loss += mse_D_loss if self.use_hinge_gan_loss: hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss( - scores_fake=scores_fake, - scores_real=scores_real, - loss_func=self.hinge_loss) - return_dict['D_hinge_gan_loss'] = hinge_D_loss - return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss - return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss + scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss + ) + return_dict["D_hinge_gan_loss"] = hinge_D_loss + return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss + return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss loss += hinge_D_loss - return_dict['D_loss'] = loss + return_dict["D_loss"] = loss return return_dict diff --git a/TTS/vocoder/layers/melgan.py b/TTS/vocoder/layers/melgan.py index 58c12a2e..7fd999d9 100644 --- a/TTS/vocoder/layers/melgan.py +++ b/TTS/vocoder/layers/melgan.py @@ -4,7 +4,7 @@ from torch.nn.utils import weight_norm class ResidualStack(nn.Module): def __init__(self, channels, num_res_blocks, kernel_size): - super(ResidualStack, self).__init__() + super().__init__() assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." base_padding = (kernel_size - 1) // 2 @@ -12,26 +12,23 @@ class ResidualStack(nn.Module): self.blocks = nn.ModuleList() for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size**idx + layer_dilation = layer_kernel_size ** idx layer_padding = base_padding * layer_dilation - self.blocks += [nn.Sequential( - nn.LeakyReLU(0.2), - nn.ReflectionPad1d(layer_padding), - weight_norm( - nn.Conv1d(channels, - channels, - kernel_size=kernel_size, - dilation=layer_dilation, - bias=True)), - nn.LeakyReLU(0.2), - weight_norm( - nn.Conv1d(channels, channels, kernel_size=1, bias=True)), - )] + self.blocks += [ + nn.Sequential( + nn.LeakyReLU(0.2), + nn.ReflectionPad1d(layer_padding), + weight_norm( + nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True) + ), + nn.LeakyReLU(0.2), + weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)), + ) + ] - self.shortcuts = nn.ModuleList([ - weight_norm(nn.Conv1d(channels, channels, kernel_size=1, - bias=True)) for i in range(num_res_blocks) - ]) + self.shortcuts = nn.ModuleList( + [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)] + ) def forward(self, x): for block, shortcut in zip(self.blocks, self.shortcuts): diff --git a/TTS/vocoder/layers/parallel_wavegan.py b/TTS/vocoder/layers/parallel_wavegan.py index bedfe551..889e8aa6 100644 --- a/TTS/vocoder/layers/parallel_wavegan.py +++ b/TTS/vocoder/layers/parallel_wavegan.py @@ -4,54 +4,44 @@ from torch.nn import functional as F class ResidualBlock(torch.nn.Module): """Residual block module in WaveNet.""" - def __init__(self, - kernel_size=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - dropout=0.0, - dilation=1, - bias=True, - use_causal_conv=False): - super(ResidualBlock, self).__init__() + + def __init__( + self, + kernel_size=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False, + ): + super().__init__() self.dropout = dropout # no future time stamps available if use_causal_conv: padding = (kernel_size - 1) * dilation else: - assert (kernel_size - - 1) % 2 == 0, "Not support even number kernel size." + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." padding = (kernel_size - 1) // 2 * dilation self.use_causal_conv = use_causal_conv # dilation conv - self.conv = torch.nn.Conv1d(res_channels, - gate_channels, - kernel_size, - padding=padding, - dilation=dilation, - bias=bias) + self.conv = torch.nn.Conv1d( + res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias + ) # local conditioning if aux_channels > 0: - self.conv1x1_aux = torch.nn.Conv1d(aux_channels, - gate_channels, - 1, - bias=False) + self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False) else: self.conv1x1_aux = None # conv output is split into two groups gate_out_channels = gate_channels // 2 - self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, - res_channels, - 1, - bias=bias) - self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, - skip_channels, - 1, - bias=bias) + self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias) + self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias) def forward(self, x, c): """ @@ -63,7 +53,7 @@ class ResidualBlock(torch.nn.Module): x = self.conv(x) # remove future time steps if use_causal_conv conv - x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x + x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x # split into two part for gated activation splitdim = 1 @@ -82,6 +72,6 @@ class ResidualBlock(torch.nn.Module): s = self.conv1x1_skip(x) # for residual connection - x = (self.conv1x1_out(x) + residual) * (0.5**2) + x = (self.conv1x1_out(x) + residual) * (0.5 ** 2) return x, s diff --git a/TTS/vocoder/layers/pqmf.py b/TTS/vocoder/layers/pqmf.py index d31953d6..5cfbf729 100644 --- a/TTS/vocoder/layers/pqmf.py +++ b/TTS/vocoder/layers/pqmf.py @@ -9,21 +9,21 @@ from scipy import signal as sig # https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan class PQMF(torch.nn.Module): def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): - super(PQMF, self).__init__() + super().__init__() self.N = N self.taps = taps self.cutoff = cutoff self.beta = beta - QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta)) + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) H = np.zeros((N, len(QMF))) G = np.zeros((N, len(QMF))) for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / - (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) # TODO: (taps - 1) -> taps - phase = (-1)**k * np.pi / 4 + constant_factor = ( + (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + ) # TODO: (taps - 1) -> taps + phase = (-1) ** k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) G[k] = 2 * QMF * np.cos(constant_factor - phase) @@ -49,8 +49,6 @@ class PQMF(torch.nn.Module): return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N) def synthesis(self, x): - x = F.conv_transpose1d(x, - self.updown_filter * self.N, - stride=self.N) + x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N) x = F.conv1d(x, self.G, padding=self.taps // 2) return x diff --git a/TTS/vocoder/layers/upsample.py b/TTS/vocoder/layers/upsample.py index 13406875..e169db00 100644 --- a/TTS/vocoder/layers/upsample.py +++ b/TTS/vocoder/layers/upsample.py @@ -4,31 +4,31 @@ from torch.nn import functional as F class Stretch2d(torch.nn.Module): def __init__(self, x_scale, y_scale, mode="nearest"): - super(Stretch2d, self).__init__() + super().__init__() self.x_scale = x_scale self.y_scale = y_scale self.mode = mode def forward(self, x): """ - x (Tensor): Input tensor (B, C, F, T). - Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + x (Tensor): Input tensor (B, C, F, T). + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), """ - return F.interpolate( - x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + return F.interpolate(x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) class UpsampleNetwork(torch.nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - upsample_factors, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - use_causal_conv=False, - ): - super(UpsampleNetwork, self).__init__() + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + super().__init__() self.use_causal_conv = use_causal_conv self.up_layers = torch.nn.ModuleList() for scale in upsample_factors: @@ -54,8 +54,8 @@ class UpsampleNetwork(torch.nn.Module): def forward(self, c): """ - c : (B, C, T_in). - Tensor: (B, C, T_upsample) + c : (B, C, T_in). + Tensor: (B, C, T_upsample) """ c = c.unsqueeze(1) # (B, 1, C, T) for f in self.up_layers: @@ -65,17 +65,18 @@ class UpsampleNetwork(torch.nn.Module): class ConvUpsample(torch.nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - upsample_factors, - nonlinear_activation=None, - nonlinear_activation_params={}, - interpolate_mode="nearest", - freq_axis_kernel_size=1, - aux_channels=80, - aux_context_window=0, - use_causal_conv=False - ): - super(ConvUpsample, self).__init__() + def __init__( + self, + upsample_factors, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False, + ): + super().__init__() self.aux_context_window = aux_context_window self.use_causal_conv = use_causal_conv and aux_context_window > 0 # To capture wide-context information in conditional features @@ -97,5 +98,5 @@ class ConvUpsample(torch.nn.Module): Tensor: (B, C, T_upsampled), """ c_ = self.conv_in(c) - c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ + c = c_[:, :, : -self.aux_context_window] if self.use_causal_conv else c_ return self.upsample(c) diff --git a/TTS/vocoder/layers/wavegrad.py b/TTS/vocoder/layers/wavegrad.py index 81f03124..83cd4233 100644 --- a/TTS/vocoder/layers/wavegrad.py +++ b/TTS/vocoder/layers/wavegrad.py @@ -13,6 +13,7 @@ class Conv1d(nn.Conv1d): class PositionalEncoding(nn.Module): """Positional encoding with noise level conditioning""" + def __init__(self, n_channels, max_len=10000): super().__init__() self.n_channels = n_channels @@ -23,9 +24,7 @@ class PositionalEncoding(nn.Module): def forward(self, x, noise_level): if x.shape[2] > self.pe.shape[1]: self.init_pe_matrix(x.shape[1], x.shape[2], x) - return x + noise_level[..., None, - None] + self.pe[:, :x.size(2)].repeat( - x.shape[0], 1, 1) / self.C + return x + noise_level[..., None, None] + self.pe[:, : x.size(2)].repeat(x.shape[0], 1, 1) / self.C def init_pe_matrix(self, n_channels, max_len, x): pe = torch.zeros(max_len, n_channels) @@ -79,30 +78,18 @@ class UBlock(nn.Module): self.factor = factor self.res_block = Conv1d(input_size, hidden_size, 1) - self.main_block = nn.ModuleList([ - Conv1d(input_size, - hidden_size, - 3, - dilation=dilation[0], - padding=dilation[0]), - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[1], - padding=dilation[1]) - ]) - self.out_block = nn.ModuleList([ - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[2], - padding=dilation[2]), - Conv1d(hidden_size, - hidden_size, - 3, - dilation=dilation[3], - padding=dilation[3]) - ]) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=dilation[0], padding=dilation[0]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[1], padding=dilation[1]), + ] + ) + self.out_block = nn.ModuleList( + [ + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[2], padding=dilation[2]), + Conv1d(hidden_size, hidden_size, 3, dilation=dilation[3], padding=dilation[3]), + ] + ) def forward(self, x, shift, scale): x_inter = F.interpolate(x, size=x.shape[-1] * self.factor) @@ -147,11 +134,13 @@ class DBlock(nn.Module): super().__init__() self.factor = factor self.res_block = Conv1d(input_size, hidden_size, 1) - self.main_block = nn.ModuleList([ - Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), - Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), - Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), - ]) + self.main_block = nn.ModuleList( + [ + Conv1d(input_size, hidden_size, 3, dilation=1, padding=1), + Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2), + Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4), + ] + ) def forward(self, x): size = x.shape[-1] // self.factor diff --git a/TTS/vocoder/models/fullband_melgan_generator.py b/TTS/vocoder/models/fullband_melgan_generator.py index 52dcc75e..ee25559a 100644 --- a/TTS/vocoder/models/fullband_melgan_generator.py +++ b/TTS/vocoder/models/fullband_melgan_generator.py @@ -4,27 +4,30 @@ from TTS.vocoder.models.melgan_generator import MelganGenerator class FullbandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=4): - super().__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=4, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) @torch.no_grad() def inference(self, cond_features): cond_features = cond_features.to(self.layers[1].weight.device) cond_features = torch.nn.functional.pad( - cond_features, - (self.inference_padding, self.inference_padding), - 'replicate') + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) return self.layers(cond_features) diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py index 3847babb..fcc43665 100644 --- a/TTS/vocoder/models/melgan_discriminator.py +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -4,14 +4,16 @@ from torch.nn.utils import weight_norm class MelganDiscriminator(nn.Module): - def __init__(self, - in_channels=1, - out_channels=1, - kernel_sizes=(5, 3), - base_channels=16, - max_channels=1024, - downsample_factors=(4, 4, 4, 4)): - super(MelganDiscriminator, self).__init__() + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4, 4), + ): + super().__init__() self.layers = nn.ModuleList() layer_kernel_size = np.prod(kernel_sizes) @@ -21,31 +23,32 @@ class MelganDiscriminator(nn.Module): self.layers += [ nn.Sequential( nn.ReflectionPad1d(layer_padding), - weight_norm( - nn.Conv1d(in_channels, - base_channels, - layer_kernel_size, - stride=1)), nn.LeakyReLU(0.2, inplace=True)) + weight_norm(nn.Conv1d(in_channels, base_channels, layer_kernel_size, stride=1)), + nn.LeakyReLU(0.2, inplace=True), + ) ] # downsampling layers layer_in_channels = base_channels for downsample_factor in downsample_factors: - layer_out_channels = min(layer_in_channels * downsample_factor, - max_channels) + layer_out_channels = min(layer_in_channels * downsample_factor, max_channels) layer_kernel_size = downsample_factor * 10 + 1 layer_padding = (layer_kernel_size - 1) // 2 layer_groups = layer_in_channels // 4 self.layers += [ nn.Sequential( weight_norm( - nn.Conv1d(layer_in_channels, - layer_out_channels, - kernel_size=layer_kernel_size, - stride=downsample_factor, - padding=layer_padding, - groups=layer_groups)), - nn.LeakyReLU(0.2, inplace=True)) + nn.Conv1d( + layer_in_channels, + layer_out_channels, + kernel_size=layer_kernel_size, + stride=downsample_factor, + padding=layer_padding, + groups=layer_groups, + ) + ), + nn.LeakyReLU(0.2, inplace=True), + ) ] layer_in_channels = layer_out_channels @@ -55,19 +58,21 @@ class MelganDiscriminator(nn.Module): self.layers += [ nn.Sequential( weight_norm( - nn.Conv1d(layer_out_channels, - layer_out_channels, - kernel_size=kernel_sizes[0], - stride=1, - padding=layer_padding1)), + nn.Conv1d( + layer_out_channels, + layer_out_channels, + kernel_size=kernel_sizes[0], + stride=1, + padding=layer_padding1, + ) + ), nn.LeakyReLU(0.2, inplace=True), ), weight_norm( - nn.Conv1d(layer_out_channels, - out_channels, - kernel_size=kernel_sizes[1], - stride=1, - padding=layer_padding2)), + nn.Conv1d( + layer_out_channels, out_channels, kernel_size=kernel_sizes[1], stride=1, padding=layer_padding2 + ) + ), ] def forward(self, x): diff --git a/TTS/vocoder/models/melgan_generator.py b/TTS/vocoder/models/melgan_generator.py index 3070eac7..dabb4baa 100644 --- a/TTS/vocoder/models/melgan_generator.py +++ b/TTS/vocoder/models/melgan_generator.py @@ -6,19 +6,20 @@ from TTS.vocoder.layers.melgan import ResidualStack class MelganGenerator(nn.Module): - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MelganGenerator, self).__init__() + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__() # assert model parameters - assert (proj_kernel - - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." # setup additional model parameters base_padding = (proj_kernel - 1) // 2 @@ -29,18 +30,13 @@ class MelganGenerator(nn.Module): layers = [] layers += [ nn.ReflectionPad1d(base_padding), - weight_norm( - nn.Conv1d(in_channels, - base_channels, - kernel_size=proj_kernel, - stride=1, - bias=True)) + weight_norm(nn.Conv1d(in_channels, base_channels, kernel_size=proj_kernel, stride=1, bias=True)), ] # upsampling layers and residual stacks for idx, upsample_factor in enumerate(upsample_factors): - layer_in_channels = base_channels // (2**idx) - layer_out_channels = base_channels // (2**(idx + 1)) + layer_in_channels = base_channels // (2 ** idx) + layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor layer_output_padding = upsample_factor % 2 @@ -48,18 +44,17 @@ class MelganGenerator(nn.Module): layers += [ nn.LeakyReLU(act_slope), weight_norm( - nn.ConvTranspose1d(layer_in_channels, - layer_out_channels, - layer_filter_size, - stride=layer_stride, - padding=layer_padding, - output_padding=layer_output_padding, - bias=True)), - ResidualStack( - channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel - ) + nn.ConvTranspose1d( + layer_in_channels, + layer_out_channels, + layer_filter_size, + stride=layer_stride, + padding=layer_padding, + output_padding=layer_output_padding, + bias=True, + ) + ), + ResidualStack(channels=layer_out_channels, num_res_blocks=num_res_blocks, kernel_size=res_kernel), ] layers += [nn.LeakyReLU(act_slope)] @@ -67,13 +62,8 @@ class MelganGenerator(nn.Module): # final layer layers += [ nn.ReflectionPad1d(base_padding), - weight_norm( - nn.Conv1d(layer_out_channels, - out_channels, - proj_kernel, - stride=1, - bias=True)), - nn.Tanh() + weight_norm(nn.Conv1d(layer_out_channels, out_channels, proj_kernel, stride=1, bias=True)), + nn.Tanh(), ] self.layers = nn.Sequential(*layers) @@ -82,10 +72,7 @@ class MelganGenerator(nn.Module): def inference(self, c): c = c.to(self.layers[1].weight.device) - c = torch.nn.functional.pad( - c, - (self.inference_padding, self.inference_padding), - 'replicate') + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.layers(c) def remove_weight_norm(self): @@ -96,9 +83,11 @@ class MelganGenerator(nn.Module): except ValueError: layer.remove_weight_norm() - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/models/melgan_multiscale_discriminator.py b/TTS/vocoder/models/melgan_multiscale_discriminator.py index 0f9cca96..ce297b6d 100644 --- a/TTS/vocoder/models/melgan_multiscale_discriminator.py +++ b/TTS/vocoder/models/melgan_multiscale_discriminator.py @@ -4,31 +4,38 @@ from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator class MelganMultiscaleDiscriminator(nn.Module): - def __init__(self, - in_channels=1, - out_channels=1, - num_scales=3, - kernel_sizes=(5, 3), - base_channels=16, - max_channels=1024, - downsample_factors=(4, 4, 4), - pooling_kernel_size=4, - pooling_stride=2, - pooling_padding=1): - super(MelganMultiscaleDiscriminator, self).__init__() + def __init__( + self, + in_channels=1, + out_channels=1, + num_scales=3, + kernel_sizes=(5, 3), + base_channels=16, + max_channels=1024, + downsample_factors=(4, 4, 4), + pooling_kernel_size=4, + pooling_stride=2, + pooling_padding=1, + ): + super().__init__() - self.discriminators = nn.ModuleList([ - MelganDiscriminator(in_channels=in_channels, - out_channels=out_channels, - kernel_sizes=kernel_sizes, - base_channels=base_channels, - max_channels=max_channels, - downsample_factors=downsample_factors) - for _ in range(num_scales) - ]) - - self.pooling = nn.AvgPool1d(kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False) + self.discriminators = nn.ModuleList( + [ + MelganDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + base_channels=base_channels, + max_channels=max_channels, + downsample_factors=downsample_factors, + ) + for _ in range(num_scales) + ] + ) + self.pooling = nn.AvgPool1d( + kernel_size=pooling_kernel_size, stride=pooling_stride, padding=pooling_padding, count_include_pad=False + ) def forward(self, x): scores = list() diff --git a/TTS/vocoder/models/multiband_melgan_generator.py b/TTS/vocoder/models/multiband_melgan_generator.py index 15e7426e..0caadc09 100644 --- a/TTS/vocoder/models/multiband_melgan_generator.py +++ b/TTS/vocoder/models/multiband_melgan_generator.py @@ -5,22 +5,25 @@ from TTS.vocoder.layers.pqmf import PQMF class MultibandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MultibandMelganGenerator, - self).__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) def pqmf_analysis(self, x): @@ -33,7 +36,6 @@ class MultibandMelganGenerator(MelganGenerator): def inference(self, cond_features): cond_features = cond_features.to(self.layers[1].weight.device) cond_features = torch.nn.functional.pad( - cond_features, - (self.inference_padding, self.inference_padding), - 'replicate') + cond_features, (self.inference_padding, self.inference_padding), "replicate" + ) return self.pqmf_synthesis(self.layers(cond_features)) diff --git a/TTS/vocoder/models/parallel_wavegan_discriminator.py b/TTS/vocoder/models/parallel_wavegan_discriminator.py index 37c22695..7414c233 100644 --- a/TTS/vocoder/models/parallel_wavegan_discriminator.py +++ b/TTS/vocoder/models/parallel_wavegan_discriminator.py @@ -11,19 +11,21 @@ class ParallelWaveganDiscriminator(nn.Module): of predictions. It is a stack of convolutional blocks with dilation. """ + # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_layers=10, - conv_channels=64, - dilation_factor=1, - nonlinear_activation="LeakyReLU", - nonlinear_activation_params={"negative_slope": 0.2}, - bias=True, - ): - super(ParallelWaveganDiscriminator, self).__init__() + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + ): + super().__init__() assert (kernel_size - 1) % 2 == 0, " [!] does not support even number kernel size." assert dilation_factor > 0, " [!] dilation factor must be > 0." self.conv_layers = nn.ModuleList() @@ -36,21 +38,19 @@ class ParallelWaveganDiscriminator(nn.Module): conv_in_channels = conv_channels padding = (kernel_size - 1) // 2 * dilation conv_layer = [ - nn.Conv1d(conv_in_channels, - conv_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - bias=bias), - getattr(nn, - nonlinear_activation)(inplace=True, - **nonlinear_activation_params) + nn.Conv1d( + conv_in_channels, + conv_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + ), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), ] self.conv_layers += conv_layer padding = (kernel_size - 1) // 2 - last_conv_layer = nn.Conv1d( - conv_in_channels, out_channels, - kernel_size=kernel_size, padding=padding, bias=bias) + last_conv_layer = nn.Conv1d(conv_in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=bias) self.conv_layers += [last_conv_layer] self.apply_weight_norm() @@ -68,6 +68,7 @@ class ParallelWaveganDiscriminator(nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) def remove_weight_norm(self): @@ -77,26 +78,28 @@ class ParallelWaveganDiscriminator(nn.Module): nn.utils.remove_weight_norm(m) except ValueError: # this module didn't have weight norm return + self.apply(_remove_weight_norm) class ResidualParallelWaveganDiscriminator(nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_layers=30, - stacks=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - dropout=0.0, - bias=True, - nonlinear_activation="LeakyReLU", - nonlinear_activation_params={"negative_slope": 0.2}, - ): - super(ResidualParallelWaveganDiscriminator, self).__init__() + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_layers=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + super().__init__() assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." self.in_channels = in_channels @@ -112,14 +115,8 @@ class ResidualParallelWaveganDiscriminator(nn.Module): # define first convolution self.first_conv = nn.Sequential( - nn.Conv1d(in_channels, - res_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), + nn.Conv1d(in_channels, res_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), ) # define residual blocks @@ -140,24 +137,14 @@ class ResidualParallelWaveganDiscriminator(nn.Module): self.conv_layers += [conv] # define output layers - self.last_conv_layers = nn.ModuleList([ - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), - nn.Conv1d(skip_channels, - skip_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - getattr(nn, nonlinear_activation)(inplace=True, - **nonlinear_activation_params), - nn.Conv1d(skip_channels, - out_channels, - kernel_size=1, - padding=0, - dilation=1, - bias=True), - ]) + self.last_conv_layers = nn.ModuleList( + [ + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, skip_channels, kernel_size=1, padding=0, dilation=1, bias=True), + getattr(nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params), + nn.Conv1d(skip_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=True), + ] + ) # apply weight norm self.apply_weight_norm() @@ -184,6 +171,7 @@ class ResidualParallelWaveganDiscriminator(nn.Module): def _apply_weight_norm(m): if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)): torch.nn.utils.weight_norm(m) + self.apply(_apply_weight_norm) def remove_weight_norm(self): diff --git a/TTS/vocoder/models/parallel_wavegan_generator.py b/TTS/vocoder/models/parallel_wavegan_generator.py index 1d1bcdcb..c9a84a0e 100644 --- a/TTS/vocoder/models/parallel_wavegan_generator.py +++ b/TTS/vocoder/models/parallel_wavegan_generator.py @@ -12,24 +12,27 @@ class ParallelWaveganGenerator(torch.nn.Module): It is conditioned on an aux feature (spectrogram) to generate an output waveform from an input noise. """ - # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=1, - out_channels=1, - kernel_size=3, - num_res_blocks=30, - stacks=3, - res_channels=64, - gate_channels=128, - skip_channels=64, - aux_channels=80, - dropout=0.0, - bias=True, - use_weight_norm=True, - upsample_factors=[4, 4, 4, 4], - inference_padding=2): - super(ParallelWaveganGenerator, self).__init__() + # pylint: disable=dangerous-default-value + def __init__( + self, + in_channels=1, + out_channels=1, + kernel_size=3, + num_res_blocks=30, + stacks=3, + res_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + bias=True, + use_weight_norm=True, + upsample_factors=[4, 4, 4, 4], + inference_padding=2, + ): + + super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.aux_channels = aux_channels @@ -46,10 +49,7 @@ class ParallelWaveganGenerator(torch.nn.Module): layers_per_stack = num_res_blocks // stacks # define first convolution - self.first_conv = torch.nn.Conv1d(in_channels, - res_channels, - kernel_size=1, - bias=True) + self.first_conv = torch.nn.Conv1d(in_channels, res_channels, kernel_size=1, bias=True) # define conv + upsampling network self.upsample_net = ConvUpsample(upsample_factors=upsample_factors) @@ -57,7 +57,7 @@ class ParallelWaveganGenerator(torch.nn.Module): # define residual blocks self.conv_layers = torch.nn.ModuleList() for layer in range(num_res_blocks): - dilation = 2**(layer % layers_per_stack) + dilation = 2 ** (layer % layers_per_stack) conv = ResidualBlock( kernel_size=kernel_size, res_channels=res_channels, @@ -71,18 +71,14 @@ class ParallelWaveganGenerator(torch.nn.Module): self.conv_layers += [conv] # define output layers - self.last_conv_layers = torch.nn.ModuleList([ - torch.nn.ReLU(inplace=True), - torch.nn.Conv1d(skip_channels, - skip_channels, - kernel_size=1, - bias=True), - torch.nn.ReLU(inplace=True), - torch.nn.Conv1d(skip_channels, - out_channels, - kernel_size=1, - bias=True), - ]) + self.last_conv_layers = torch.nn.ModuleList( + [ + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, skip_channels, kernel_size=1, bias=True), + torch.nn.ReLU(inplace=True), + torch.nn.Conv1d(skip_channels, out_channels, kernel_size=1, bias=True), + ] + ) # apply weight norm if use_weight_norm: @@ -90,8 +86,8 @@ class ParallelWaveganGenerator(torch.nn.Module): def forward(self, c): """ - c: (B, C ,T'). - o: Output tensor (B, out_channels, T) + c: (B, C ,T'). + o: Output tensor (B, out_channels, T) """ # random noise x = torch.randn([c.shape[0], 1, c.shape[2] * self.upsample_scale]) @@ -100,8 +96,9 @@ class ParallelWaveganGenerator(torch.nn.Module): # perform upsampling if c is not None and self.upsample_net is not None: c = self.upsample_net(c) - assert c.shape[-1] == x.shape[ - -1], f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" + assert ( + c.shape[-1] == x.shape[-1] + ), f" [!] Upsampling scale does not match the expected output. {c.shape} vs {x.shape}" # encode to hidden representation x = self.first_conv(x) @@ -121,8 +118,7 @@ class ParallelWaveganGenerator(torch.nn.Module): @torch.no_grad() def inference(self, c): c = c.to(self.first_conv.weight.device) - c = torch.nn.functional.pad( - c, (self.inference_padding, self.inference_padding), 'replicate') + c = torch.nn.functional.pad(c, (self.inference_padding, self.inference_padding), "replicate") return self.forward(c) def remove_weight_norm(self): @@ -144,10 +140,7 @@ class ParallelWaveganGenerator(torch.nn.Module): self.apply(_apply_weight_norm) @staticmethod - def _get_receptive_field_size(layers, - stacks, - kernel_size, - dilation=lambda x: 2**x): + def _get_receptive_field_size(layers, stacks, kernel_size, dilation=lambda x: 2 ** x): assert layers % stacks == 0 layers_per_cycle = layers // stacks dilations = [dilation(i % layers_per_cycle) for i in range(layers)] @@ -155,12 +148,13 @@ class ParallelWaveganGenerator(torch.nn.Module): @property def receptive_field_size(self): - return self._get_receptive_field_size(self.layers, self.stacks, - self.kernel_size) + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/models/random_window_discriminator.py b/TTS/vocoder/models/random_window_discriminator.py index 3efd395e..ea95668a 100644 --- a/TTS/vocoder/models/random_window_discriminator.py +++ b/TTS/vocoder/models/random_window_discriminator.py @@ -4,7 +4,7 @@ from torch import nn class GBlock(nn.Module): def __init__(self, in_channels, cond_channels, downsample_factor): - super(GBlock, self).__init__() + super().__init__() self.in_channels = in_channels self.cond_channels = cond_channels @@ -13,20 +13,16 @@ class GBlock(nn.Module): self.start = nn.Sequential( nn.AvgPool1d(downsample_factor, stride=downsample_factor), nn.ReLU(), - nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1)) - self.lc_conv1d = nn.Conv1d(cond_channels, - in_channels * 2, - kernel_size=1) + nn.Conv1d(in_channels, in_channels * 2, kernel_size=3, padding=1), + ) + self.lc_conv1d = nn.Conv1d(cond_channels, in_channels * 2, kernel_size=1) self.end = nn.Sequential( - nn.ReLU(), - nn.Conv1d(in_channels * 2, - in_channels * 2, - kernel_size=3, - dilation=2, - padding=2)) + nn.ReLU(), nn.Conv1d(in_channels * 2, in_channels * 2, kernel_size=3, dilation=2, padding=2) + ) self.residual = nn.Sequential( nn.Conv1d(in_channels, in_channels * 2, kernel_size=1), - nn.AvgPool1d(downsample_factor, stride=downsample_factor)) + nn.AvgPool1d(downsample_factor, stride=downsample_factor), + ) def forward(self, inputs, conditions): outputs = self.start(inputs) + self.lc_conv1d(conditions) @@ -39,42 +35,34 @@ class GBlock(nn.Module): class DBlock(nn.Module): def __init__(self, in_channels, out_channels, downsample_factor): - super(DBlock, self).__init__() + super().__init__() self.in_channels = in_channels self.downsample_factor = downsample_factor self.out_channels = out_channels - self.donwsample_layer = nn.AvgPool1d(downsample_factor, - stride=downsample_factor) + self.donwsample_layer = nn.AvgPool1d(downsample_factor, stride=downsample_factor) self.layers = nn.Sequential( nn.ReLU(), nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv1d(out_channels, - out_channels, - kernel_size=3, - dilation=2, - padding=2)) + nn.Conv1d(out_channels, out_channels, kernel_size=3, dilation=2, padding=2), + ) self.residual = nn.Sequential( - nn.Conv1d(in_channels, out_channels, kernel_size=1), ) + nn.Conv1d(in_channels, out_channels, kernel_size=1), + ) def forward(self, inputs): if self.downsample_factor > 1: - outputs = self.layers(self.donwsample_layer(inputs))\ - + self.donwsample_layer(self.residual(inputs)) + outputs = self.layers(self.donwsample_layer(inputs)) + self.donwsample_layer(self.residual(inputs)) else: outputs = self.layers(inputs) + self.residual(inputs) return outputs class ConditionalDiscriminator(nn.Module): - def __init__(self, - in_channels, - cond_channels, - downsample_factors=(2, 2, 2), - out_channels=(128, 256)): - super(ConditionalDiscriminator, self).__init__() + def __init__(self, in_channels, cond_channels, downsample_factors=(2, 2, 2), out_channels=(128, 256)): + super().__init__() assert len(downsample_factors) == len(out_channels) + 1 @@ -90,13 +78,11 @@ class ConditionalDiscriminator(nn.Module): self.pre_cond_layers += [DBlock(in_channels, 64, 1)] in_channels = 64 for (i, channel) in enumerate(out_channels): - self.pre_cond_layers.append( - DBlock(in_channels, channel, downsample_factors[i])) + self.pre_cond_layers.append(DBlock(in_channels, channel, downsample_factors[i])) in_channels = channel # condition block - self.cond_block = GBlock(in_channels, cond_channels, - downsample_factors[-1]) + self.cond_block = GBlock(in_channels, cond_channels, downsample_factors[-1]) # layers after condition block self.post_cond_layers += [ @@ -119,12 +105,8 @@ class ConditionalDiscriminator(nn.Module): class UnconditionalDiscriminator(nn.Module): - def __init__(self, - in_channels, - base_channels=64, - downsample_factors=(8, 4), - out_channels=(128, 256)): - super(UnconditionalDiscriminator, self).__init__() + def __init__(self, in_channels, base_channels=64, downsample_factors=(8, 4), out_channels=(128, 256)): + super().__init__() self.downsample_factors = downsample_factors self.in_channels = in_channels @@ -155,17 +137,18 @@ class UnconditionalDiscriminator(nn.Module): class RandomWindowDiscriminator(nn.Module): """Random Window Discriminator as described in http://arxiv.org/abs/1909.11646""" - def __init__(self, - cond_channels, - hop_length, - uncond_disc_donwsample_factors=(8, 4), - cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), - (8, 4, 2), (8, 4), (4, 2, 2)), - cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), - (128, 256), (256, ), (128, 256)), - window_sizes=(512, 1024, 2048, 4096, 8192)): - super(RandomWindowDiscriminator, self).__init__() + def __init__( + self, + cond_channels, + hop_length, + uncond_disc_donwsample_factors=(8, 4), + cond_disc_downsample_factors=((8, 4, 2, 2, 2), (8, 4, 2, 2), (8, 4, 2), (8, 4), (4, 2, 2)), + cond_disc_out_channels=((128, 128, 256, 256), (128, 256, 256), (128, 256), (256,), (128, 256)), + window_sizes=(512, 1024, 2048, 4096, 8192), + ): + + super().__init__() self.cond_channels = cond_channels self.window_sizes = window_sizes self.hop_length = hop_length @@ -173,8 +156,7 @@ class RandomWindowDiscriminator(nn.Module): self.ks = [ws // self.base_window_size for ws in window_sizes] # check arguments - assert len(cond_disc_downsample_factors) == len( - cond_disc_out_channels) == len(window_sizes) + assert len(cond_disc_downsample_factors) == len(cond_disc_out_channels) == len(window_sizes) for ws in window_sizes: assert ws % hop_length == 0 @@ -185,9 +167,8 @@ class RandomWindowDiscriminator(nn.Module): self.unconditional_discriminators = nn.ModuleList([]) for k in self.ks: layer = UnconditionalDiscriminator( - in_channels=k, - base_channels=64, - downsample_factors=uncond_disc_donwsample_factors) + in_channels=k, base_channels=64, downsample_factors=uncond_disc_donwsample_factors + ) self.unconditional_discriminators.append(layer) self.conditional_discriminators = nn.ModuleList([]) @@ -196,29 +177,27 @@ class RandomWindowDiscriminator(nn.Module): in_channels=k, cond_channels=cond_channels, downsample_factors=cond_disc_downsample_factors[idx], - out_channels=cond_disc_out_channels[idx]) + out_channels=cond_disc_out_channels[idx], + ) self.conditional_discriminators.append(layer) def forward(self, x, c): scores = [] feats = [] # unconditional pass - for (window_size, layer) in zip(self.window_sizes, - self.unconditional_discriminators): + for (window_size, layer) in zip(self.window_sizes, self.unconditional_discriminators): index = np.random.randint(x.shape[-1] - window_size) - score = layer(x[:, :, index:index + window_size]) + score = layer(x[:, :, index : index + window_size]) scores.append(score) # conditional pass - for (window_size, layer) in zip(self.window_sizes, - self.conditional_discriminators): + for (window_size, layer) in zip(self.window_sizes, self.conditional_discriminators): frame_size = window_size // self.hop_length lc_index = np.random.randint(c.shape[-1] - frame_size) sample_index = lc_index * self.hop_length - x_sub = x[:, :, - sample_index:(lc_index + frame_size) * self.hop_length] - c_sub = c[:, :, lc_index:lc_index + frame_size] + x_sub = x[:, :, sample_index : (lc_index + frame_size) * self.hop_length] + c_sub = c[:, :, lc_index : lc_index + frame_size] score = layer(x_sub, c_sub) scores.append(score) diff --git a/TTS/vocoder/models/wavegrad.py b/TTS/vocoder/models/wavegrad.py index 96951ad1..ef8e8add 100644 --- a/TTS/vocoder/models/wavegrad.py +++ b/TTS/vocoder/models/wavegrad.py @@ -8,17 +8,18 @@ from ..layers.wavegrad import DBlock, FiLM, UBlock, Conv1d class Wavegrad(nn.Module): # pylint: disable=dangerous-default-value - def __init__(self, - in_channels=80, - out_channels=1, - use_weight_norm=False, - y_conv_channels=32, - x_conv_channels=768, - dblock_out_channels=[128, 128, 256, 512], - ublock_out_channels=[512, 512, 256, 128, 128], - upsample_factors=[5, 5, 3, 2, 2], - upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], - [1, 2, 4, 8], [1, 2, 4, 8]]): + def __init__( + self, + in_channels=80, + out_channels=1, + use_weight_norm=False, + y_conv_channels=32, + x_conv_channels=768, + dblock_out_channels=[128, 128, 256, 512], + ublock_out_channels=[512, 512, 256, 128, 128], + upsample_factors=[5, 5, 3, 2, 2], + upsample_dilations=[[1, 2, 1, 2], [1, 2, 1, 2], [1, 2, 4, 8], [1, 2, 4, 8], [1, 2, 4, 8]], + ): super().__init__() self.use_weight_norm = use_weight_norm @@ -72,14 +73,13 @@ class Wavegrad(nn.Module): shift_and_scale.append(film(x, noise_scale)) x = self.x_conv(spectrogram) - for layer, (film_shift, film_scale) in zip(self.ublocks, - reversed(shift_and_scale)): + for layer, (film_shift, film_scale) in zip(self.ublocks, reversed(shift_and_scale)): x = layer(x, film_shift, film_scale) x = self.out_conv(x) return x def load_noise_schedule(self, path): - beta = np.load(path, allow_pickle=True).item()['beta'] # pylint: disable=unexpected-keyword-arg + beta = np.load(path, allow_pickle=True).item()["beta"] # pylint: disable=unexpected-keyword-arg self.compute_noise_level(beta) @torch.no_grad() @@ -91,26 +91,24 @@ class Wavegrad(nn.Module): y_n = torch.FloatTensor(y_n).unsqueeze(0).unsqueeze(0).to(x) sqrt_alpha_hat = self.noise_level.to(x) for n in range(len(self.alpha) - 1, -1, -1): - y_n = self.c1[n] * (y_n - self.c2[n] * self.forward( - y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) + y_n = self.c1[n] * (y_n - self.c2[n] * self.forward(y_n, x, sqrt_alpha_hat[n].repeat(x.shape[0]))) if n > 0: z = torch.randn_like(y_n) y_n += self.sigma[n - 1] * z y_n.clamp_(-1.0, 1.0) return y_n - def compute_y_n(self, y_0): """Compute noisy audio based on noise schedule""" self.noise_level = self.noise_level.to(y_0) if len(y_0.shape) == 3: y_0 = y_0.squeeze(1) s = torch.randint(0, self.num_steps - 1, [y_0.shape[0]]) - l_a, l_b = self.noise_level[s], self.noise_level[s+1] + l_a, l_b = self.noise_level[s], self.noise_level[s + 1] noise_scale = l_a + torch.rand(y_0.shape[0]).to(y_0) * (l_b - l_a) noise_scale = noise_scale.unsqueeze(1) noise = torch.randn_like(y_0) - noisy_audio = noise_scale * y_0 + (1.0 - noise_scale**2)**0.5 * noise + noisy_audio = noise_scale * y_0 + (1.0 - noise_scale ** 2) ** 0.5 * noise return noise.unsqueeze(1), noisy_audio.unsqueeze(1), noise_scale[:, 0] def compute_noise_level(self, beta): @@ -127,9 +125,9 @@ class Wavegrad(nn.Module): self.alpha_hat = torch.tensor(alpha_hat.astype(np.float32)) self.noise_level = torch.tensor(noise_level.astype(np.float32)) - self.c1 = 1 / self.alpha**0.5 - self.c2 = (1 - self.alpha) / (1 - self.alpha_hat)**0.5 - self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:])**0.5 + self.c1 = 1 / self.alpha ** 0.5 + self.c2 = (1 - self.alpha) / (1 - self.alpha_hat) ** 0.5 + self.sigma = ((1.0 - self.alpha_hat[:-1]) / (1.0 - self.alpha_hat[1:]) * self.beta[1:]) ** 0.5 def remove_weight_norm(self): for _, layer in enumerate(self.dblocks): @@ -146,7 +144,6 @@ class Wavegrad(nn.Module): except ValueError: layer.remove_weight_norm() - for _, layer in enumerate(self.ublocks): if len(layer.state_dict()) != 0: try: @@ -167,7 +164,6 @@ class Wavegrad(nn.Module): if len(layer.state_dict()) != 0: layer.apply_weight_norm() - for _, layer in enumerate(self.ublocks): if len(layer.state_dict()) != 0: layer.apply_weight_norm() @@ -176,21 +172,26 @@ class Wavegrad(nn.Module): self.out_conv = weight_norm(self.out_conv) self.y_conv = weight_norm(self.y_conv) - - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training if self.use_weight_norm: self.remove_weight_norm() - betas = np.linspace(config['test_noise_schedule']['min_val'], - config['test_noise_schedule']['max_val'], - config['test_noise_schedule']['num_steps']) + betas = np.linspace( + config["test_noise_schedule"]["min_val"], + config["test_noise_schedule"]["max_val"], + config["test_noise_schedule"]["num_steps"], + ) self.compute_noise_level(betas) else: - betas = np.linspace(config['train_noise_schedule']['min_val'], - config['train_noise_schedule']['max_val'], - config['train_noise_schedule']['num_steps']) + betas = np.linspace( + config["train_noise_schedule"]["min_val"], + config["train_noise_schedule"]["max_val"], + config["train_noise_schedule"]["num_steps"], + ) self.compute_noise_level(betas) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index dbcaea66..ca4ea3f8 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -16,6 +16,7 @@ from TTS.vocoder.utils.distribution import ( def stream(string, variables): sys.stdout.write(f"\r{string}" % variables) + # pylint: disable=abstract-method # relates https://github.com/pytorch/pytorch/issues/42305 class ResBlock(nn.Module): @@ -40,8 +41,7 @@ class MelResNet(nn.Module): def __init__(self, num_res_blocks, in_dims, compute_dims, res_out_dims, pad): super().__init__() k_size = pad * 2 + 1 - self.conv_in = nn.Conv1d( - in_dims, compute_dims, kernel_size=k_size, bias=False) + self.conv_in = nn.Conv1d(in_dims, compute_dims, kernel_size=k_size, bias=False) self.batch_norm = nn.BatchNorm1d(compute_dims) self.layers = nn.ModuleList() for _ in range(num_res_blocks): @@ -73,31 +73,28 @@ class Stretch2d(nn.Module): class UpsampleNetwork(nn.Module): def __init__( - self, - feat_dims, - upsample_scales, - compute_dims, - num_res_blocks, - res_out_dims, - pad, - use_aux_net, - ): + self, + feat_dims, + upsample_scales, + compute_dims, + num_res_blocks, + res_out_dims, + pad, + use_aux_net, + ): super().__init__() self.total_scale = np.cumproduct(upsample_scales)[-1] self.indent = pad * self.total_scale self.use_aux_net = use_aux_net if use_aux_net: - self.resnet = MelResNet( - num_res_blocks, feat_dims, compute_dims, res_out_dims, pad - ) + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) self.resnet_stretch = Stretch2d(self.total_scale, 1) self.up_layers = nn.ModuleList() for scale in upsample_scales: k_size = (1, scale * 2 + 1) padding = (0, scale) stretch = Stretch2d(scale, 1) - conv = nn.Conv2d(1, 1, kernel_size=k_size, - padding=padding, bias=False) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) conv.weight.data.fill_(1.0 / k_size[1]) self.up_layers.append(stretch) self.up_layers.append(conv) @@ -113,56 +110,51 @@ class UpsampleNetwork(nn.Module): m = m.unsqueeze(1) for f in self.up_layers: m = f(m) - m = m.squeeze(1)[:, :, self.indent: -self.indent] + m = m.squeeze(1)[:, :, self.indent : -self.indent] return m.transpose(1, 2), aux class Upsample(nn.Module): - def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, - res_out_dims, use_aux_net): + def __init__(self, scale, pad, num_res_blocks, feat_dims, compute_dims, res_out_dims, use_aux_net): super().__init__() self.scale = scale self.pad = pad self.indent = pad * scale self.use_aux_net = use_aux_net - self.resnet = MelResNet(num_res_blocks, feat_dims, - compute_dims, res_out_dims, pad) + self.resnet = MelResNet(num_res_blocks, feat_dims, compute_dims, res_out_dims, pad) def forward(self, m): if self.use_aux_net: aux = self.resnet(m) - aux = torch.nn.functional.interpolate( - aux, scale_factor=self.scale, mode="linear", align_corners=True - ) + aux = torch.nn.functional.interpolate(aux, scale_factor=self.scale, mode="linear", align_corners=True) aux = aux.transpose(1, 2) else: aux = None - m = torch.nn.functional.interpolate( - m, scale_factor=self.scale, mode="linear", align_corners=True - ) - m = m[:, :, self.indent: -self.indent] + m = torch.nn.functional.interpolate(m, scale_factor=self.scale, mode="linear", align_corners=True) + m = m[:, :, self.indent : -self.indent] m = m * 0.045 # empirically found return m.transpose(1, 2), aux class WaveRNN(nn.Module): - def __init__(self, - rnn_dims, - fc_dims, - mode, - mulaw, - pad, - use_aux_net, - use_upsample_net, - upsample_factors, - feat_dims, - compute_dims, - res_out_dims, - num_res_blocks, - hop_length, - sample_rate, - ): + def __init__( + self, + rnn_dims, + fc_dims, + mode, + mulaw, + pad, + use_aux_net, + use_upsample_net, + upsample_factors, + feat_dims, + compute_dims, + res_out_dims, + num_res_blocks, + hop_length, + sample_rate, + ): super().__init__() self.mode = mode self.mulaw = mulaw @@ -209,8 +201,7 @@ class WaveRNN(nn.Module): if self.use_aux_net: self.I = nn.Linear(feat_dims + self.aux_dims + 1, rnn_dims) self.rnn1 = nn.GRU(rnn_dims, rnn_dims, batch_first=True) - self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, - rnn_dims, batch_first=True) + self.rnn2 = nn.GRU(rnn_dims + self.aux_dims, rnn_dims, batch_first=True) self.fc1 = nn.Linear(rnn_dims + self.aux_dims, fc_dims) self.fc2 = nn.Linear(fc_dims + self.aux_dims, fc_dims) self.fc3 = nn.Linear(fc_dims, self.n_classes) @@ -230,10 +221,10 @@ class WaveRNN(nn.Module): if self.use_aux_net: aux_idx = [self.aux_dims * i for i in range(5)] - a1 = aux[:, :, aux_idx[0]: aux_idx[1]] - a2 = aux[:, :, aux_idx[1]: aux_idx[2]] - a3 = aux[:, :, aux_idx[2]: aux_idx[3]] - a4 = aux[:, :, aux_idx[3]: aux_idx[4]] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] x = ( torch.cat([x.unsqueeze(-1), mels, a1], dim=2) @@ -276,8 +267,7 @@ class WaveRNN(nn.Module): mels = mels.unsqueeze(0) wave_len = (mels.size(-1) - 1) * self.hop_length - mels = self.pad_tensor(mels.transpose( - 1, 2), pad=self.pad, side="both") + mels = self.pad_tensor(mels.transpose(1, 2), pad=self.pad, side="both") mels, aux = self.upsample(mels.transpose(1, 2)) if batched: @@ -293,7 +283,7 @@ class WaveRNN(nn.Module): if self.use_aux_net: d = self.aux_dims - aux_split = [aux[:, :, d * i: d * (i + 1)] for i in range(4)] + aux_split = [aux[:, :, d * i : d * (i + 1)] for i in range(4)] for i in range(seq_len): @@ -302,11 +292,7 @@ class WaveRNN(nn.Module): if self.use_aux_net: a1_t, a2_t, a3_t, a4_t = (a[:, i, :] for a in aux_split) - x = ( - torch.cat([x, m_t, a1_t], dim=1) - if self.use_aux_net - else torch.cat([x, m_t], dim=1) - ) + x = torch.cat([x, m_t, a1_t], dim=1) if self.use_aux_net else torch.cat([x, m_t], dim=1) x = self.I(x) h1 = rnn1(x, h1) @@ -324,14 +310,11 @@ class WaveRNN(nn.Module): logits = self.fc3(x) if self.mode == "mold": - sample = sample_from_discretized_mix_logistic( - logits.unsqueeze(0).transpose(1, 2) - ) + sample = sample_from_discretized_mix_logistic(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).to(device) elif self.mode == "gauss": - sample = sample_from_gaussian( - logits.unsqueeze(0).transpose(1, 2)) + sample = sample_from_gaussian(logits.unsqueeze(0).transpose(1, 2)) output.append(sample.view(-1)) x = sample.transpose(0, 1).to(device) elif isinstance(self.mode, int): @@ -342,8 +325,7 @@ class WaveRNN(nn.Module): output.append(sample) x = sample.unsqueeze(-1) else: - raise RuntimeError( - "Unknown model mode value - ", self.mode) + raise RuntimeError("Unknown model mode value - ", self.mode) if i % 100 == 0: self.gen_display(i, seq_len, b_size, start) @@ -366,7 +348,7 @@ class WaveRNN(nn.Module): output = output[:wave_len] if wave_len > len(fade_out): - output[-20 * self.hop_length:] *= fade_out + output[-20 * self.hop_length :] *= fade_out self.train() return output @@ -411,8 +393,7 @@ class WaveRNN(nn.Module): padding = target + 2 * overlap - remaining x = self.pad_tensor(x, padding, side="after") - folded = torch.zeros(num_folds, target + 2 * - overlap, features).to(x.device) + folded = torch.zeros(num_folds, target + 2 * overlap, features).to(x.device) # Get the values for the folded tensor for i in range(num_folds): @@ -439,7 +420,7 @@ class WaveRNN(nn.Module): total = t + 2 * pad if side == "both" else t + pad padded = torch.zeros(b, total, c).to(x.device) if side in ("before", "both"): - padded[:, pad: pad + t, :] = x + padded[:, pad : pad + t, :] = x elif side == "after": padded[:, :t, :] = x return padded @@ -500,9 +481,11 @@ class WaveRNN(nn.Module): return unfolded - def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.load_state_dict(state['model']) + def load_checkpoint( + self, config, checkpoint_path, eval=False + ): # pylint: disable=unused-argument, redefined-builtin + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.load_state_dict(state["model"]) if eval: self.eval() assert not self.training diff --git a/TTS/vocoder/tf/layers/melgan.py b/TTS/vocoder/tf/layers/melgan.py index 34b25d65..90bce6f1 100644 --- a/TTS/vocoder/tf/layers/melgan.py +++ b/TTS/vocoder/tf/layers/melgan.py @@ -3,7 +3,7 @@ import tensorflow as tf class ReflectionPad1d(tf.keras.layers.Layer): def __init__(self, padding): - super(ReflectionPad1d, self).__init__() + super().__init__() self.padding = padding def call(self, x): @@ -12,7 +12,7 @@ class ReflectionPad1d(tf.keras.layers.Layer): class ResidualStack(tf.keras.layers.Layer): def __init__(self, channels, num_res_blocks, kernel_size, name): - super(ResidualStack, self).__init__(name=name) + super().__init__(name=name) assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd." base_padding = (kernel_size - 1) // 2 @@ -21,29 +21,27 @@ class ResidualStack(tf.keras.layers.Layer): num_layers = 2 for idx in range(num_res_blocks): layer_kernel_size = kernel_size - layer_dilation = layer_kernel_size**idx + layer_dilation = layer_kernel_size ** idx layer_padding = base_padding * layer_dilation block = [ tf.keras.layers.LeakyReLU(0.2), ReflectionPad1d(layer_padding), - tf.keras.layers.Conv2D(filters=channels, - kernel_size=(kernel_size, 1), - dilation_rate=(layer_dilation, 1), - use_bias=True, - padding='valid', - name=f'blocks.{idx}.{num_layers}'), + tf.keras.layers.Conv2D( + filters=channels, + kernel_size=(kernel_size, 1), + dilation_rate=(layer_dilation, 1), + use_bias=True, + padding="valid", + name=f"blocks.{idx}.{num_layers}", + ), tf.keras.layers.LeakyReLU(0.2), - tf.keras.layers.Conv2D(filters=channels, - kernel_size=(1, 1), - use_bias=True, - name=f'blocks.{idx}.{num_layers + 2}') + tf.keras.layers.Conv2D( + filters=channels, kernel_size=(1, 1), use_bias=True, name=f"blocks.{idx}.{num_layers + 2}" + ), ] self.blocks.append(block) self.shortcuts = [ - tf.keras.layers.Conv2D(channels, - kernel_size=1, - use_bias=True, - name=f'shortcuts.{i}') + tf.keras.layers.Conv2D(channels, kernel_size=1, use_bias=True, name=f"shortcuts.{i}") for i in range(num_res_blocks) ] diff --git a/TTS/vocoder/tf/layers/pqmf.py b/TTS/vocoder/tf/layers/pqmf.py index c018971f..e1b5055a 100644 --- a/TTS/vocoder/tf/layers/pqmf.py +++ b/TTS/vocoder/tf/layers/pqmf.py @@ -6,28 +6,26 @@ from scipy import signal as sig class PQMF(tf.keras.layers.Layer): def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0): - super(PQMF, self).__init__() + super().__init__() # define filter coefficient self.N = N self.taps = taps self.cutoff = cutoff self.beta = beta - QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta)) + QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta)) H = np.zeros((N, len(QMF))) G = np.zeros((N, len(QMF))) for k in range(N): - constant_factor = (2 * k + 1) * (np.pi / - (2 * N)) * (np.arange(taps + 1) - - ((taps - 1) / 2)) - phase = (-1)**k * np.pi / 4 + constant_factor = (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2)) + phase = (-1) ** k * np.pi / 4 H[k] = 2 * QMF * np.cos(constant_factor + phase) G[k] = 2 * QMF * np.cos(constant_factor - phase) # [N, 1, taps + 1] == [filter_width, in_channels, out_channels] - self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype('float32') - self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype('float32') + self.H = np.transpose(H[:, None, :], (2, 1, 0)).astype("float32") + self.G = np.transpose(G[None, :, :], (2, 1, 0)).astype("float32") # filter for downsampling & upsampling updown_filter = np.zeros((N, N, N), dtype=np.float32) @@ -41,11 +39,8 @@ class PQMF(tf.keras.layers.Layer): """ x = tf.transpose(x, perm=[0, 2, 1]) x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) - x = tf.nn.conv1d(x, self.H, stride=1, padding='VALID') - x = tf.nn.conv1d(x, - self.updown_filter, - stride=self.N, - padding='VALID') + x = tf.nn.conv1d(x, self.H, stride=1, padding="VALID") + x = tf.nn.conv1d(x, self.updown_filter, stride=self.N, padding="VALID") x = tf.transpose(x, perm=[0, 2, 1]) return x @@ -58,8 +53,8 @@ class PQMF(tf.keras.layers.Layer): x, self.updown_filter * self.N, strides=self.N, - output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, - self.N)) + output_shape=(tf.shape(x)[0], tf.shape(x)[1] * self.N, self.N), + ) x = tf.pad(x, [[0, 0], [self.taps // 2, self.taps // 2], [0, 0]], constant_values=0.0) x = tf.nn.conv1d(x, self.G, stride=1, padding="VALID") x = tf.transpose(x, perm=[0, 2, 1]) diff --git a/TTS/vocoder/tf/models/melgan_generator.py b/TTS/vocoder/tf/models/melgan_generator.py index 9a029df4..0a8a0b73 100644 --- a/TTS/vocoder/tf/models/melgan_generator.py +++ b/TTS/vocoder/tf/models/melgan_generator.py @@ -1,33 +1,35 @@ import logging import os -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # FATAL -logging.getLogger('tensorflow').setLevel(logging.FATAL) +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # FATAL +logging.getLogger("tensorflow").setLevel(logging.FATAL) import tensorflow as tf from TTS.vocoder.tf.layers.melgan import ResidualStack, ReflectionPad1d -#pylint: disable=too-many-ancestors -#pylint: disable=abstract-method +# pylint: disable=too-many-ancestors +# pylint: disable=abstract-method class MelganGenerator(tf.keras.models.Model): - """ Melgan Generator TF implementation dedicated for inference with no - weight norm """ - def __init__(self, - in_channels=80, - out_channels=1, - proj_kernel=7, - base_channels=512, - upsample_factors=(8, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MelganGenerator, self).__init__() + """Melgan Generator TF implementation dedicated for inference with no + weight norm""" + + def __init__( + self, + in_channels=80, + out_channels=1, + proj_kernel=7, + base_channels=512, + upsample_factors=(8, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__() self.in_channels = in_channels # assert model parameters - assert (proj_kernel - - 1) % 2 == 0, " [!] proj_kernel should be an odd number." + assert (proj_kernel - 1) % 2 == 0, " [!] proj_kernel should be an odd number." # setup additional model parameters base_padding = (proj_kernel - 1) // 2 @@ -37,19 +39,16 @@ class MelganGenerator(tf.keras.models.Model): # initial layer self.initial_layer = [ ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D(filters=base_channels, - kernel_size=(proj_kernel, 1), - strides=1, - padding='valid', - use_bias=True, - name="1") + tf.keras.layers.Conv2D( + filters=base_channels, kernel_size=(proj_kernel, 1), strides=1, padding="valid", use_bias=True, name="1" + ), ] num_layers = 3 # count number of layers for layer naming # upsampling layers and residual stacks self.upsample_layers = [] for idx, upsample_factor in enumerate(upsample_factors): - layer_out_channels = base_channels // (2**(idx + 1)) + layer_out_channels = base_channels // (2 ** (idx + 1)) layer_filter_size = upsample_factor * 2 layer_stride = upsample_factor # layer_output_padding = upsample_factor % 2 @@ -59,14 +58,17 @@ class MelganGenerator(tf.keras.models.Model): filters=layer_out_channels, kernel_size=(layer_filter_size, 1), strides=(layer_stride, 1), - padding='same', + padding="same", # output_padding=layer_output_padding, use_bias=True, - name=f'{num_layers}'), - ResidualStack(channels=layer_out_channels, - num_res_blocks=num_res_blocks, - kernel_size=res_kernel, - name=f'layers.{num_layers + 1}') + name=f"{num_layers}", + ), + ResidualStack( + channels=layer_out_channels, + num_res_blocks=num_res_blocks, + kernel_size=res_kernel, + name=f"layers.{num_layers + 1}", + ), ] num_layers += num_res_blocks - 1 @@ -75,11 +77,10 @@ class MelganGenerator(tf.keras.models.Model): # final layer self.final_layers = [ ReflectionPad1d(base_padding), - tf.keras.layers.Conv2D(filters=out_channels, - kernel_size=(proj_kernel, 1), - use_bias=True, - name=f'layers.{num_layers + 1}'), - tf.keras.layers.Activation("tanh") + tf.keras.layers.Conv2D( + filters=out_channels, kernel_size=(proj_kernel, 1), use_bias=True, name=f"layers.{num_layers + 1}" + ), + tf.keras.layers.Activation("tanh"), ] # self.model_layers = tf.keras.models.Sequential(self.initial_layer + self.upsample_layers + self.final_layers, name="layers") @@ -114,7 +115,8 @@ class MelganGenerator(tf.keras.models.Model): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, None, None], dtype=tf.float32), - ],) + ], + ) def inference_tflite(self, c): c = tf.transpose(c, perm=[0, 2, 1]) c = tf.expand_dims(c, 2) diff --git a/TTS/vocoder/tf/models/multiband_melgan_generator.py b/TTS/vocoder/tf/models/multiband_melgan_generator.py index bdd333ed..23836659 100644 --- a/TTS/vocoder/tf/models/multiband_melgan_generator.py +++ b/TTS/vocoder/tf/models/multiband_melgan_generator.py @@ -3,25 +3,28 @@ import tensorflow as tf from TTS.vocoder.tf.models.melgan_generator import MelganGenerator from TTS.vocoder.tf.layers.pqmf import PQMF -#pylint: disable=too-many-ancestors -#pylint: disable=abstract-method +# pylint: disable=too-many-ancestors +# pylint: disable=abstract-method class MultibandMelganGenerator(MelganGenerator): - def __init__(self, - in_channels=80, - out_channels=4, - proj_kernel=7, - base_channels=384, - upsample_factors=(2, 8, 2, 2), - res_kernel=3, - num_res_blocks=3): - super(MultibandMelganGenerator, - self).__init__(in_channels=in_channels, - out_channels=out_channels, - proj_kernel=proj_kernel, - base_channels=base_channels, - upsample_factors=upsample_factors, - res_kernel=res_kernel, - num_res_blocks=num_res_blocks) + def __init__( + self, + in_channels=80, + out_channels=4, + proj_kernel=7, + base_channels=384, + upsample_factors=(2, 8, 2, 2), + res_kernel=3, + num_res_blocks=3, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + proj_kernel=proj_kernel, + base_channels=base_channels, + upsample_factors=upsample_factors, + res_kernel=res_kernel, + num_res_blocks=num_res_blocks, + ) self.pqmf_layer = PQMF(N=4, taps=62, cutoff=0.15, beta=9.0) def pqmf_analysis(self, x): @@ -46,7 +49,8 @@ class MultibandMelganGenerator(MelganGenerator): experimental_relax_shapes=True, input_signature=[ tf.TensorSpec([1, 80, None], dtype=tf.float32), - ],) + ], + ) def inference_tflite(self, c): c = tf.transpose(c, perm=[0, 2, 1]) c = tf.expand_dims(c, 2) diff --git a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py index 25139cc3..5e0427b1 100644 --- a/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py +++ b/TTS/vocoder/tf/utils/convert_torch_to_tf_utils.py @@ -10,14 +10,14 @@ def compare_torch_tf(torch_tensor, tf_tensor): def convert_tf_name(tf_name): """ Convert certain patterns in TF layer names to Torch patterns """ tf_name_tmp = tf_name - tf_name_tmp = tf_name_tmp.replace(':0', '') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0') - tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1') - tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh') - tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight') - tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight') - tf_name_tmp = tf_name_tmp.replace('/beta', '/bias') - tf_name_tmp = tf_name_tmp.replace('/', '.') + tf_name_tmp = tf_name_tmp.replace(":0", "") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0") + tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1") + tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh") + tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight") + tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight") + tf_name_tmp = tf_name_tmp.replace("/beta", "/bias") + tf_name_tmp = tf_name_tmp.replace("/", ".") return tf_name_tmp @@ -26,15 +26,17 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict): print(" > Passing weights from Torch to TF ...") for tf_var in tf_vars: torch_var_name = var_map_dict[tf_var.name] - print(f' | > {tf_var.name} <-- {torch_var_name}') + print(f" | > {tf_var.name} <-- {torch_var_name}") # if tuple, it is a bias variable - if 'kernel' in tf_var.name: + if "kernel" in tf_var.name: torch_weight = state_dict[torch_var_name] numpy_weight = torch_weight.permute([2, 1, 0]).numpy()[:, None, :, :] - if 'bias' in tf_var.name: + if "bias" in tf_var.name: torch_weight = state_dict[torch_var_name] numpy_weight = torch_weight - assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" + assert np.all( + tf_var.shape == numpy_weight.shape + ), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}" tf.keras.backend.set_value(tf_var, numpy_weight) return tf_vars diff --git a/TTS/vocoder/tf/utils/generic_utils.py b/TTS/vocoder/tf/utils/generic_utils.py index 0daf2d6e..97cb9ae7 100644 --- a/TTS/vocoder/tf/utils/generic_utils.py +++ b/TTS/vocoder/tf/utils/generic_utils.py @@ -4,32 +4,33 @@ import importlib def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module('TTS.vocoder.tf.models.' + - c.generator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.tf.models." + c.generator_model.lower()) MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model in 'melgan_generator': + if c.generator_model in "melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - if c.generator_model in 'melgan_fb_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + if c.generator_model in "melgan_fb_generator": pass - if c.generator_model in 'multiband_melgan_generator': + if c.generator_model in "multiband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=4, proj_kernel=7, base_channels=384, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) return model diff --git a/TTS/vocoder/tf/utils/io.py b/TTS/vocoder/tf/utils/io.py index c73c9cd8..6ffa302c 100644 --- a/TTS/vocoder/tf/utils/io.py +++ b/TTS/vocoder/tf/utils/io.py @@ -6,19 +6,19 @@ import tensorflow as tf def save_checkpoint(model, current_step, epoch, output_path, **kwargs): """ Save TF Vocoder model """ state = { - 'model': model.weights, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": model.weights, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) - pickle.dump(state, open(output_path, 'wb')) + pickle.dump(state, open(output_path, "wb")) def load_checkpoint(model, checkpoint_path): """ Load TF Vocoder model """ - checkpoint = pickle.load(open(checkpoint_path, 'rb')) - chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']} + checkpoint = pickle.load(open(checkpoint_path, "rb")) + chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]} tf_vars = model.weights for tf_var in tf_vars: layer_name = tf_var.name diff --git a/TTS/vocoder/tf/utils/tflite.py b/TTS/vocoder/tf/utils/tflite.py index d62a081a..e0c630b9 100644 --- a/TTS/vocoder/tf/utils/tflite.py +++ b/TTS/vocoder/tf/utils/tflite.py @@ -1,25 +1,20 @@ import tensorflow as tf -def convert_melgan_to_tflite(model, - output_path=None, - experimental_converter=True): +def convert_melgan_to_tflite(model, output_path=None, experimental_converter=True): """Convert Tensorflow MelGAN model to TFLite. Save a binary file if output_path is provided, else return TFLite model.""" concrete_function = model.inference_tflite.get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions( - [concrete_function]) + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function]) converter.experimental_new_converter = experimental_converter converter.optimizations = [] - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS - ] + converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] tflite_model = converter.convert() - print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.') + print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.") if output_path is not None: # same model binary if outputpath is provided - with open(output_path, 'wb') as f: + with open(output_path, "wb") as f: f.write(tflite_model) return None return tflite_model diff --git a/TTS/vocoder/utils/distribution.py b/TTS/vocoder/utils/distribution.py index b0553ed0..3f62b7ad 100644 --- a/TTS/vocoder/utils/distribution.py +++ b/TTS/vocoder/utils/distribution.py @@ -11,11 +11,7 @@ def gaussian_loss(y_hat, y, log_std_min=-7.0): mean = y_hat[:, :, :1] log_std = torch.clamp(y_hat[:, :, 1:], min=log_std_min) # TODO: replace with pytorch dist - log_probs = -0.5 * ( - -math.log(2.0 * math.pi) - - 2.0 * log_std - - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std)) - ) + log_probs = -0.5 * (-math.log(2.0 * math.pi) - 2.0 * log_std - torch.pow(y - mean, 2) * torch.exp((-2.0 * log_std))) return log_probs.squeeze().mean() @@ -28,8 +24,7 @@ def sample_from_gaussian(y_hat, log_std_min=-7.0, scale_factor=1.0): torch.exp(log_std), ) sample = dist.sample() - sample = torch.clamp(torch.clamp( - sample, min=-scale_factor), max=scale_factor) + sample = torch.clamp(torch.clamp(sample, min=-scale_factor), max=scale_factor) del dist return sample @@ -44,11 +39,7 @@ def log_sum_exp(x): # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py -def discretized_mix_logistic_loss(y_hat, - y, - num_classes=65536, - log_scale_min=None, - reduce=True): +def discretized_mix_logistic_loss(y_hat, y, num_classes=65536, log_scale_min=None, reduce=True): if log_scale_min is None: log_scale_min = float(np.log(1e-14)) y_hat = y_hat.permute(0, 2, 1) @@ -61,9 +52,8 @@ def discretized_mix_logistic_loss(y_hat, # unpack parameters. (B, T, num_mixtures) x 3 logit_probs = y_hat[:, :, :nr_mix] - means = y_hat[:, :, nr_mix: 2 * nr_mix] - log_scales = torch.clamp( - y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=log_scale_min) + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) # B x T x 1 -> B x T x num_mixtures y = y.expand_as(means) @@ -103,14 +93,11 @@ def discretized_mix_logistic_loss(y_hat, # for num_classes=65536 case? 1e-7? not sure.. inner_inner_cond = (cdf_delta > 1e-5).float() - inner_inner_out = inner_inner_cond * torch.log( - torch.clamp(cdf_delta, min=1e-12) - ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) - inner_cond = (y > 0.999).float() - inner_out = ( - inner_cond * log_one_minus_cdf_min + - (1.0 - inner_cond) * inner_inner_out + inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * ( + log_pdf_mid - np.log((num_classes - 1) / 2) ) + inner_cond = (y > 0.999).float() + inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out cond = (y < -0.999).float() log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out @@ -147,10 +134,8 @@ def sample_from_discretized_mix_logistic(y, log_scale_min=None): # (B, T) -> (B, T, nr_mix) one_hot = to_one_hot(argmax, nr_mix) # select logistic parameters - means = torch.sum(y[:, :, nr_mix: 2 * nr_mix] * one_hot, dim=-1) - log_scales = torch.clamp( - torch.sum(y[:, :, 2 * nr_mix: 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min - ) + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.clamp(torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1), min=log_scale_min) # sample from logistic & clip to interval # we don't actually round to the nearest 8bit value when sampling u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) diff --git a/TTS/vocoder/utils/generic_utils.py b/TTS/vocoder/utils/generic_utils.py index 0d532063..35102295 100644 --- a/TTS/vocoder/utils/generic_utils.py +++ b/TTS/vocoder/utils/generic_utils.py @@ -21,11 +21,9 @@ def interpolate_vocoder_input(scale_factor, spec): """ print(" > before interpolation :", spec.shape) spec = torch.tensor(spec).unsqueeze(0).unsqueeze(0) # pylint: disable=not-callable - spec = torch.nn.functional.interpolate(spec, - scale_factor=scale_factor, - recompute_scale_factor=True, - mode='bilinear', - align_corners=False).squeeze(0) + spec = torch.nn.functional.interpolate( + spec, scale_factor=scale_factor, recompute_scale_factor=True, mode="bilinear", align_corners=False + ).squeeze(0) print(" > after interpolation :", spec.shape) return spec @@ -63,132 +61,131 @@ def plot_results(y_hat, y, ap, global_step, name_prefix): def to_camel(text): text = text.capitalize() - return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text) + return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text) def setup_generator(c): print(" > Generator Model: {}".format(c.generator_model)) - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.generator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.models." + c.generator_model.lower()) # this is to preserve the WaveRNN class name (instead of Wavernn) - if c.generator_model.lower() == 'wavernn': - MyModel = getattr(MyModel, 'WaveRNN') + if c.generator_model.lower() == "wavernn": + MyModel = getattr(MyModel, "WaveRNN") else: MyModel = getattr(MyModel, to_camel(c.generator_model)) - if c.generator_model.lower() in 'wavernn': + if c.generator_model.lower() in "wavernn": model = MyModel( - rnn_dims=c.wavernn_model_params['rnn_dims'], - fc_dims=c.wavernn_model_params['fc_dims'], + rnn_dims=c.wavernn_model_params["rnn_dims"], + fc_dims=c.wavernn_model_params["fc_dims"], mode=c.mode, mulaw=c.mulaw, pad=c.padding, - use_aux_net=c.wavernn_model_params['use_aux_net'], - use_upsample_net=c.wavernn_model_params['use_upsample_net'], - upsample_factors=c.wavernn_model_params['upsample_factors'], - feat_dims=c.audio['num_mels'], - compute_dims=c.wavernn_model_params['compute_dims'], - res_out_dims=c.wavernn_model_params['res_out_dims'], - num_res_blocks=c.wavernn_model_params['num_res_blocks'], + use_aux_net=c.wavernn_model_params["use_aux_net"], + use_upsample_net=c.wavernn_model_params["use_upsample_net"], + upsample_factors=c.wavernn_model_params["upsample_factors"], + feat_dims=c.audio["num_mels"], + compute_dims=c.wavernn_model_params["compute_dims"], + res_out_dims=c.wavernn_model_params["res_out_dims"], + num_res_blocks=c.wavernn_model_params["num_res_blocks"], hop_length=c.audio["hop_length"], - sample_rate=c.audio["sample_rate"],) - elif c.generator_model.lower() in 'melgan_generator': + sample_rate=c.audio["sample_rate"], + ) + elif c.generator_model.lower() in "melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model in 'melgan_fb_generator': - raise ValueError( - 'melgan_fb_generator is now fullband_melgan_generator') - elif c.generator_model.lower() in 'multiband_melgan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model in "melgan_fb_generator": + raise ValueError("melgan_fb_generator is now fullband_melgan_generator") + elif c.generator_model.lower() in "multiband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=4, proj_kernel=7, base_channels=384, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model.lower() in 'fullband_melgan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "fullband_melgan_generator": model = MyModel( - in_channels=c.audio['num_mels'], + in_channels=c.audio["num_mels"], out_channels=1, proj_kernel=7, base_channels=512, - upsample_factors=c.generator_model_params['upsample_factors'], + upsample_factors=c.generator_model_params["upsample_factors"], res_kernel=3, - num_res_blocks=c.generator_model_params['num_res_blocks']) - elif c.generator_model.lower() in 'parallel_wavegan_generator': + num_res_blocks=c.generator_model_params["num_res_blocks"], + ) + elif c.generator_model.lower() in "parallel_wavegan_generator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_res_blocks=c.generator_model_params['num_res_blocks'], - stacks=c.generator_model_params['stacks'], + num_res_blocks=c.generator_model_params["num_res_blocks"], + stacks=c.generator_model_params["stacks"], res_channels=64, gate_channels=128, skip_channels=64, - aux_channels=c.audio['num_mels'], + aux_channels=c.audio["num_mels"], dropout=0.0, bias=True, use_weight_norm=True, - upsample_factors=c.generator_model_params['upsample_factors']) - elif c.generator_model.lower() in 'wavegrad': + upsample_factors=c.generator_model_params["upsample_factors"], + ) + elif c.generator_model.lower() in "wavegrad": model = MyModel( - in_channels=c['audio']['num_mels'], + in_channels=c["audio"]["num_mels"], out_channels=1, - use_weight_norm=c['model_params']['use_weight_norm'], - x_conv_channels=c['model_params']['x_conv_channels'], - y_conv_channels=c['model_params']['y_conv_channels'], - dblock_out_channels=c['model_params']['dblock_out_channels'], - ublock_out_channels=c['model_params']['ublock_out_channels'], - upsample_factors=c['model_params']['upsample_factors'], - upsample_dilations=c['model_params']['upsample_dilations']) + use_weight_norm=c["model_params"]["use_weight_norm"], + x_conv_channels=c["model_params"]["x_conv_channels"], + y_conv_channels=c["model_params"]["y_conv_channels"], + dblock_out_channels=c["model_params"]["dblock_out_channels"], + ublock_out_channels=c["model_params"]["ublock_out_channels"], + upsample_factors=c["model_params"]["upsample_factors"], + upsample_dilations=c["model_params"]["upsample_dilations"], + ) else: - raise NotImplementedError( - f'Model {c.generator_model} not implemented!') + raise NotImplementedError(f"Model {c.generator_model} not implemented!") return model def setup_discriminator(c): print(" > Discriminator Model: {}".format(c.discriminator_model)) - if 'parallel_wavegan' in c.discriminator_model: - MyModel = importlib.import_module( - 'TTS.vocoder.models.parallel_wavegan_discriminator') + if "parallel_wavegan" in c.discriminator_model: + MyModel = importlib.import_module("TTS.vocoder.models.parallel_wavegan_discriminator") else: - MyModel = importlib.import_module('TTS.vocoder.models.' + - c.discriminator_model.lower()) + MyModel = importlib.import_module("TTS.vocoder.models." + c.discriminator_model.lower()) MyModel = getattr(MyModel, to_camel(c.discriminator_model.lower())) - if c.discriminator_model in 'random_window_discriminator': + if c.discriminator_model in "random_window_discriminator": model = MyModel( - cond_channels=c.audio['num_mels'], - hop_length=c.audio['hop_length'], - uncond_disc_donwsample_factors=c. - discriminator_model_params['uncond_disc_donwsample_factors'], - cond_disc_downsample_factors=c. - discriminator_model_params['cond_disc_downsample_factors'], - cond_disc_out_channels=c. - discriminator_model_params['cond_disc_out_channels'], - window_sizes=c.discriminator_model_params['window_sizes']) - if c.discriminator_model in 'melgan_multiscale_discriminator': + cond_channels=c.audio["num_mels"], + hop_length=c.audio["hop_length"], + uncond_disc_donwsample_factors=c.discriminator_model_params["uncond_disc_donwsample_factors"], + cond_disc_downsample_factors=c.discriminator_model_params["cond_disc_downsample_factors"], + cond_disc_out_channels=c.discriminator_model_params["cond_disc_out_channels"], + window_sizes=c.discriminator_model_params["window_sizes"], + ) + if c.discriminator_model in "melgan_multiscale_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_sizes=(5, 3), - base_channels=c.discriminator_model_params['base_channels'], - max_channels=c.discriminator_model_params['max_channels'], - downsample_factors=c. - discriminator_model_params['downsample_factors']) - if c.discriminator_model == 'residual_parallel_wavegan_discriminator': + base_channels=c.discriminator_model_params["base_channels"], + max_channels=c.discriminator_model_params["max_channels"], + downsample_factors=c.discriminator_model_params["downsample_factors"], + ) + if c.discriminator_model == "residual_parallel_wavegan_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_layers=c.discriminator_model_params['num_layers'], - stacks=c.discriminator_model_params['stacks'], + num_layers=c.discriminator_model_params["num_layers"], + stacks=c.discriminator_model_params["stacks"], res_channels=64, gate_channels=128, skip_channels=64, @@ -197,17 +194,17 @@ def setup_discriminator(c): nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2}, ) - if c.discriminator_model == 'parallel_wavegan_discriminator': + if c.discriminator_model == "parallel_wavegan_discriminator": model = MyModel( in_channels=1, out_channels=1, kernel_size=3, - num_layers=c.discriminator_model_params['num_layers'], + num_layers=c.discriminator_model_params["num_layers"], conv_channels=64, dilation_factor=1, nonlinear_activation="LeakyReLU", nonlinear_activation_params={"negative_slope": 0.2}, - bias=True + bias=True, ) return model diff --git a/TTS/vocoder/utils/io.py b/TTS/vocoder/utils/io.py index f3bc9bad..8a5d144d 100644 --- a/TTS/vocoder/utils/io.py +++ b/TTS/vocoder/utils/io.py @@ -9,11 +9,11 @@ from TTS.utils.io import RenamingUnpickler def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pylint: disable=redefined-builtin try: - state = torch.load(checkpoint_path, map_location=torch.device('cpu')) + state = torch.load(checkpoint_path, map_location=torch.device("cpu")) except ModuleNotFoundError: pickle_tts.Unpickler = RenamingUnpickler - state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts) - model.load_state_dict(state['model']) + state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts) + model.load_state_dict(state["model"]) if use_cuda: model.cuda() if eval: @@ -21,76 +21,104 @@ def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False): # pyli return model, state -def save_model(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, output_path, **kwargs): - if hasattr(model, 'module'): +def save_model( + model, optimizer, scheduler, model_disc, optimizer_disc, scheduler_disc, current_step, epoch, output_path, **kwargs +): + if hasattr(model, "module"): model_state = model.module.state_dict() else: model_state = model.state_dict() - model_disc_state = model_disc.state_dict()\ - if model_disc is not None else None - optimizer_state = optimizer.state_dict()\ - if optimizer is not None else None - optimizer_disc_state = optimizer_disc.state_dict()\ - if optimizer_disc is not None else None - scheduler_state = scheduler.state_dict()\ - if scheduler is not None else None - scheduler_disc_state = scheduler_disc.state_dict()\ - if scheduler_disc is not None else None + model_disc_state = model_disc.state_dict() if model_disc is not None else None + optimizer_state = optimizer.state_dict() if optimizer is not None else None + optimizer_disc_state = optimizer_disc.state_dict() if optimizer_disc is not None else None + scheduler_state = scheduler.state_dict() if scheduler is not None else None + scheduler_disc_state = scheduler_disc.state_dict() if scheduler_disc is not None else None state = { - 'model': model_state, - 'optimizer': optimizer_state, - 'scheduler': scheduler_state, - 'model_disc': model_disc_state, - 'optimizer_disc': optimizer_disc_state, - 'scheduler_disc': scheduler_disc_state, - 'step': current_step, - 'epoch': epoch, - 'date': datetime.date.today().strftime("%B %d, %Y"), + "model": model_state, + "optimizer": optimizer_state, + "scheduler": scheduler_state, + "model_disc": model_disc_state, + "optimizer_disc": optimizer_disc_state, + "scheduler_disc": scheduler_disc_state, + "step": current_step, + "epoch": epoch, + "date": datetime.date.today().strftime("%B %d, %Y"), } state.update(kwargs) torch.save(state, output_path) -def save_checkpoint(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, output_folder, - **kwargs): - file_name = 'checkpoint_{}.pth.tar'.format(current_step) +def save_checkpoint( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + output_folder, + **kwargs, +): + file_name = "checkpoint_{}.pth.tar".format(current_step) checkpoint_path = os.path.join(output_folder, file_name) print(" > CHECKPOINT : {}".format(checkpoint_path)) - save_model(model, optimizer, scheduler, model_disc, optimizer_disc, - scheduler_disc, current_step, epoch, checkpoint_path, **kwargs) + save_model( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + checkpoint_path, + **kwargs, + ) -def save_best_model(current_loss, best_loss, model, optimizer, scheduler, - model_disc, optimizer_disc, scheduler_disc, current_step, - epoch, out_path, keep_all_best=False, keep_after=10000, - **kwargs): +def save_best_model( + current_loss, + best_loss, + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + out_path, + keep_all_best=False, + keep_after=10000, + **kwargs, +): if current_loss < best_loss: - best_model_name = f'best_model_{current_step}.pth.tar' + best_model_name = f"best_model_{current_step}.pth.tar" checkpoint_path = os.path.join(out_path, best_model_name) print(" > BEST MODEL : {}".format(checkpoint_path)) - save_model(model, - optimizer, - scheduler, - model_disc, - optimizer_disc, - scheduler_disc, - current_step, - epoch, - checkpoint_path, - model_loss=current_loss, - **kwargs) + save_model( + model, + optimizer, + scheduler, + model_disc, + optimizer_disc, + scheduler_disc, + current_step, + epoch, + checkpoint_path, + model_loss=current_loss, + **kwargs, + ) # only delete previous if current is saved successfully if not keep_all_best or (current_step < keep_after): - model_names = glob.glob( - os.path.join(out_path, 'best_model*.pth.tar')) + model_names = glob.glob(os.path.join(out_path, "best_model*.pth.tar")) for model_name in model_names: if os.path.basename(model_name) == best_model_name: continue os.remove(model_name) # create symlink to best model for convinience - link_name = 'best_model.pth.tar' + link_name = "best_model.pth.tar" link_path = os.path.join(out_path, link_name) if os.path.islink(link_path) or os.path.isfile(link_path): os.remove(link_path) diff --git a/pyproject.toml b/pyproject.toml index b6c632d8..335303d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,28 @@ [build-system] requires = ["setuptools", "wheel", "Cython", "numpy==1.17.5"] + +[flake8] +max-line-length=120 + +[tool.black] +line-length = 120 +target-version = ['py38'] +exclude = ''' + +( + /( + \.eggs # exclude a few common directories in the + | \.git # root of the project + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist + )/ + | foo.py # also separately exclude a file named foo.py in + # the root of the project +) +''' diff --git a/tests/test_audio.py b/tests/test_audio.py index c00cd8f8..e33a6d6c 100644 --- a/tests/test_audio.py +++ b/tests/test_audio.py @@ -17,7 +17,7 @@ conf = load_config(os.path.join(get_tests_input_path(), 'test_config.json')) # pylint: disable=protected-access class TestAudio(unittest.TestCase): def __init__(self, *args, **kwargs): - super(TestAudio, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.ap = AudioProcessor(**conf.audio) def test_audio_synthesis(self): diff --git a/tests/test_loader.py b/tests/test_loader.py index b79aad19..439e8d35 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -28,7 +28,7 @@ print(" > Dynamic data loader test: {}".format(DATA_EXIST)) class TestTTSDataset(unittest.TestCase): def __init__(self, *args, **kwargs): - super(TestTTSDataset, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) self.max_loader_iter = 4 self.ap = AudioProcessor(**c.audio) diff --git a/tests/test_vocoder_losses.py b/tests/test_vocoder_losses.py index d578a130..4a18ee53 100644 --- a/tests/test_vocoder_losses.py +++ b/tests/test_vocoder_losses.py @@ -22,7 +22,7 @@ def test_torch_stft(): torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) # librosa stft wav = ap.load_wav(WAV_FILE) - M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access + M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access # torch stft wav = torch.from_numpy(wav[None, :]).float() M_torch = torch_stft(wav) @@ -42,9 +42,10 @@ def test_stft_loss(): def test_multiscale_stft_loss(): - stft_loss = MultiScaleSTFTLoss([ap.fft_size//2, ap.fft_size, ap.fft_size*2], - [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2], - [ap.win_length // 2, ap.win_length, ap.win_length * 2]) + stft_loss = MultiScaleSTFTLoss( + [ap.fft_size // 2, ap.fft_size, ap.fft_size * 2], + [ap.hop_length // 2, ap.hop_length, ap.hop_length * 2], + [ap.win_length // 2, ap.win_length, ap.win_length * 2]) wav = ap.load_wav(WAV_FILE) wav = torch.from_numpy(wav[None, :]).float() loss_m, loss_sc = stft_loss(wav, wav) @@ -53,6 +54,7 @@ def test_multiscale_stft_loss(): assert loss_sc < 1.0 assert loss_m + loss_sc > 0 + def test_melgan_feature_loss(): feats_real = [] feats_fake = [] @@ -71,7 +73,6 @@ def test_melgan_feature_loss(): loss = loss_func(feats_fake, feats_real) assert loss.item() <= 1.0 - feats_real = [] feats_fake = [] @@ -89,4 +90,3 @@ def test_melgan_feature_loss(): loss_func = MelganFeatureLoss() loss = loss_func(feats_fake, feats_real) assert loss.item() == 0 -