format with black and pylint 2.7.3

pull/423/head
Eren Gölge 2021-04-09 00:38:08 +02:00
parent 5de7eb708b
commit 0e79fa86ad
127 changed files with 5511 additions and 5491 deletions
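
This is essentially a formatting-only pass: black normalizes quoting, trailing commas, and line wrapping across the scripts below, and the "# pylint: disable=bad-continuation" comments become "# pylint: disable=bad-option-value", presumably because pylint 2.6+ dropped the bad-continuation check, so the old disable comment would itself trip pylint's bad-option-value warning. As a rough illustration only (the exact black configuration is not part of this diff, and the 120-character line length is inferred from the width of the reformatted lines), the same style can be reproduced with black's Python API:

```python
# Illustrative sketch, not taken from this commit: reformat one of the argparse
# calls below with black's Python API. line_length=120 is an assumption inferred
# from how wide the reformatted lines in this diff are.
import black

src = (
    "parser.add_argument('--model_path',\n"
    "                    type=str,\n"
    "                    required=True,\n"
    "                    help='Path to Tacotron/Tacotron2 model file ')\n"
)

# format_str rewrites the snippet with double quotes and joins the call onto
# one line because it fits within the assumed 120-character limit.
formatted = black.format_str(src, mode=black.Mode(line_length=120))
print(formatted)
# -> parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
```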


@ -15,16 +15,14 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config
if __name__ == '__main__':
# pylint: disable=bad-continuation
if __name__ == "__main__":
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description='''Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''
'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''
'''
description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
"""Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n"""
"""
Example run:
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
@ -34,53 +32,44 @@ Example run:
--batch_size 32
--dataset ljspeech
--use_cuda True
''',
formatter_class=RawTextHelpFormatter
)
parser.add_argument('--model_path',
type=str,
required=True,
help='Path to Tacotron/Tacotron2 model file ')
parser.add_argument(
'--config_path',
type=str,
required=True,
help='Path to Tacotron/Tacotron2 config file.',
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument('--dataset',
type=str,
default='',
required=True,
help='Target dataset processor name from TTS.tts.dataset.preprocess.')
parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
parser.add_argument(
'--dataset_metafile',
"--config_path",
type=str,
default='',
required=True,
help='Dataset metafile including file paths with transcripts.')
help="Path to Tacotron/Tacotron2 config file.",
)
parser.add_argument(
'--data_path',
"--dataset",
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--use_cuda',
type=bool,
default=False,
help="enable/disable cuda.")
default="",
required=True,
help="Target dataset processor name from TTS.tts.dataset.preprocess.",
)
parser.add_argument(
'--batch_size',
default=16,
type=int,
help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
"--dataset_metafile",
type=str,
default="",
required=True,
help="Dataset metafile inclusing file paths with transcripts.",
)
parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.")
parser.add_argument(
"--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
)
args = parser.parse_args()
C = load_config(args.config_path)
ap = AudioProcessor(**C.audio)
# if the vocabulary was passed, replace the default
if 'characters' in C.keys():
if "characters" in C.keys():
symbols, phonemes = make_symbols(**C.characters)
# load the model
@ -91,28 +80,32 @@ Example run:
model.eval()
# data loader
preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset(model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,
ap=ap,
meta_data=meta_data,
tp=C.characters if 'characters' in C.keys() else None,
add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
use_phonemes=C.use_phonemes,
phoneme_cache_path=C.phoneme_cache_path,
phoneme_language=C.phoneme_language,
enable_eos_bos=C.enable_eos_bos_chars)
dataset = MyDataset(
model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,
ap=ap,
meta_data=meta_data,
tp=C.characters if "characters" in C.keys() else None,
add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
use_phonemes=C.use_phonemes,
phoneme_cache_path=C.phoneme_cache_path,
phoneme_language=C.phoneme_language,
enable_eos_bos=C.enable_eos_bos_chars,
)
dataset.sort_items()
loader = DataLoader(dataset,
batch_size=args.batch_size,
num_workers=4,
collate_fn=dataset.collate_fn,
shuffle=False,
drop_last=False)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=4,
collate_fn=dataset.collate_fn,
shuffle=False,
drop_last=False,
)
# compute attentions
file_paths = []
@ -134,25 +127,29 @@ Example run:
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
text_input, text_lengths, mel_input)
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)
alignments = alignments.detach()
for idx, alignment in enumerate(alignments):
item_idx = item_idxs[idx]
# interpolate if r > 1
alignment = torch.nn.functional.interpolate(
alignment.transpose(0, 1).unsqueeze(0),
size=None,
scale_factor=model.decoder.r,
mode='nearest',
align_corners=None,
recompute_scale_factor=None).squeeze(0).transpose(0, 1)
alignment = (
torch.nn.functional.interpolate(
alignment.transpose(0, 1).unsqueeze(0),
size=None,
scale_factor=model.decoder.r,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
)
.squeeze(0)
.transpose(0, 1)
)
# remove paddings
alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
# set file paths
wav_file_name = os.path.basename(item_idx)
align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
file_path = item_idx.replace(wav_file_name, align_file_name)
# save output
file_paths.append([item_idx, file_path])


@ -13,91 +13,72 @@ from TTS.tts.utils.speakers import save_speaker_mapping
from TTS.tts.datasets.preprocess import load_meta_data
parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.')
parser.add_argument(
'model_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).')
parser.add_argument(
'config_path',
type=str,
help='Path to config file for training.',
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
)
parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).")
parser.add_argument(
'data_path',
"config_path",
type=str,
help='Data path for wav files - directory or CSV file')
help="Path to config file for training.",
)
parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file")
parser.add_argument("output_path", type=str, help="path for training outputs.")
parser.add_argument(
'output_path',
"--target_dataset",
type=str,
help='path for training outputs.')
parser.add_argument(
'--target_dataset',
type=str,
default='',
help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.'
)
parser.add_argument(
'--use_cuda', type=bool, help='flag to set cuda.', default=False
)
parser.add_argument(
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
default="",
help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.",
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False)
parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|")
args = parser.parse_args()
c = load_config(args.config_path)
ap = AudioProcessor(**c['audio'])
ap = AudioProcessor(**c["audio"])
data_path = args.data_path
split_ext = os.path.splitext(data_path)
sep = args.separator
if args.target_dataset != '':
if args.target_dataset != "":
# if target dataset is defined
dataset_config = [
{
"name": args.target_dataset,
"path": args.data_path,
"meta_file_train": None,
"meta_file_val": None
},
{"name": args.target_dataset, "path": args.data_path, "meta_file_train": None, "meta_file_val": None},
]
wav_files, _ = load_meta_data(dataset_config, eval_split=False)
output_files = [wav_file[1].replace(data_path, args.output_path).replace(
'.wav', '.npy') for wav_file in wav_files]
output_files = [wav_file[1].replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files]
else:
# if target dataset is not defined
if len(split_ext) > 0 and split_ext[1].lower() == '.csv':
if len(split_ext) > 0 and split_ext[1].lower() == ".csv":
# Parse CSV
print(f'CSV file: {data_path}')
print(f"CSV file: {data_path}")
with open(data_path) as f:
wav_path = os.path.join(os.path.dirname(data_path), 'wavs')
wav_path = os.path.join(os.path.dirname(data_path), "wavs")
wav_files = []
print(f'Separator is: {sep}')
print(f"Separator is: {sep}")
for line in f:
components = line.split(sep)
if len(components) != 2:
print("Invalid line")
continue
wav_file = os.path.join(wav_path, components[0] + '.wav')
#print(f'wav_file: {wav_file}')
wav_file = os.path.join(wav_path, components[0] + ".wav")
# print(f'wav_file: {wav_file}')
if os.path.exists(wav_file):
wav_files.append(wav_file)
print(f'Count of wavs imported: {len(wav_files)}')
print(f"Count of wavs imported: {len(wav_files)}")
else:
# Parse all wav files in data_path
wav_files = glob.glob(data_path + '/**/*.wav', recursive=True)
wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)
output_files = [wav_file.replace(data_path, args.output_path).replace(
'.wav', '.npy') for wav_file in wav_files]
output_files = [wav_file.replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files]
for output_file in output_files:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
# define Encoder model
model = SpeakerEncoder(**c.model)
model.load_state_dict(torch.load(args.model_path)['model'])
model.load_state_dict(torch.load(args.model_path)["model"])
model.eval()
if args.use_cuda:
model.cuda()
@ -117,14 +98,14 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
embedd = embedd.detach().cpu().numpy()
np.save(output_files[idx], embedd)
if args.target_dataset != '':
if args.target_dataset != "":
# create speaker_mapping if target dataset is defined
wav_file_name = os.path.basename(wav_file)
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]['name'] = speaker_name
speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()
if args.target_dataset != '':
if args.target_dataset != "":
# save speaker_mapping if target dataset is defined
mapping_file_path = os.path.join(args.output_path, 'speakers.json')
mapping_file_path = os.path.join(args.output_path, "speakers.json")
save_speaker_mapping(args.output_path, speaker_mapping)


@ -15,25 +15,24 @@ from TTS.utils.audio import AudioProcessor
def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Compute mean and variance of spectrogtram features.")
parser.add_argument("--config_path", type=str, required=True,
help="TTS config file path to define audio processin parameters.")
parser.add_argument("--out_path", type=str, required=True,
help="save path (directory and filename).")
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument(
"--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
)
parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
args = parser.parse_args()
# load config
CONFIG = load_config(args.config_path)
CONFIG.audio['signal_norm'] = False # do not apply earlier normalization
CONFIG.audio['stats_path'] = None # discard pre-defined stats
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization
CONFIG.audio["stats_path"] = None # discard pre-defined stats
# load audio processor
ap = AudioProcessor(**CONFIG.audio)
# load the meta data of target dataset
if 'data_path' in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True)
if "data_path" in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
else:
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
print(f" > There are {len(dataset_items)} files.")
@ -63,27 +62,27 @@ def main():
output_file_path = args.out_path
stats = {}
stats['mel_mean'] = mel_mean
stats['mel_std'] = mel_scale
stats['linear_mean'] = linear_mean
stats['linear_std'] = linear_scale
stats["mel_mean"] = mel_mean
stats["mel_std"] = mel_scale
stats["linear_mean"] = linear_mean
stats["linear_std"] = linear_scale
print(f' > Avg mel spec mean: {mel_mean.mean()}')
print(f' > Avg mel spec scale: {mel_scale.mean()}')
print(f' > Avg linear spec mean: {linear_mean.mean()}')
print(f' > Avg linear spec scale: {linear_scale.mean()}')
print(f" > Avg mel spec mean: {mel_mean.mean()}")
print(f" > Avg mel spec scale: {mel_scale.mean()}")
print(f" > Avg linear spec mean: {linear_mean.mean()}")
print(f" > Avg lienar spec scale: {linear_scale.mean()}")
# set default config values for mean-var scaling
CONFIG.audio['stats_path'] = output_file_path
CONFIG.audio['signal_norm'] = True
CONFIG.audio["stats_path"] = output_file_path
CONFIG.audio["signal_norm"] = True
# remove redundant values
del CONFIG.audio['max_norm']
del CONFIG.audio['min_level_db']
del CONFIG.audio['symmetric_norm']
del CONFIG.audio['clip_norm']
stats['audio_config'] = CONFIG.audio
del CONFIG.audio["max_norm"]
del CONFIG.audio["min_level_db"]
del CONFIG.audio["symmetric_norm"]
del CONFIG.audio["clip_norm"]
stats["audio_config"] = CONFIG.audio
np.save(output_file_path, stats, allow_pickle=True)
print(f' > stats saved to {output_file_path}')
print(f" > stats saved to {output_file_path}")
if __name__ == "__main__":


@ -9,15 +9,9 @@ from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite
parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to tflite output binary.')
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants


@ -8,27 +8,22 @@ import torch
from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.generic_utils import \
setup_generator as setup_tf_generator
compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator
# prevent GPU use
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# define args
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument(
'--output_path',
type=str,
help='path to output file including file name to save TF model.')
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
@ -38,9 +33,8 @@ num_speakers = 0
# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path,
map_location=torch.device('cpu'))
state_dict = checkpoint['model']
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()
@ -48,7 +42,7 @@ state_dict = model.state_dict()
# init tf model
model_tf = setup_tf_generator(c)
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
@ -66,10 +60,7 @@ for tf_name in tf_var_names:
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names
]
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
@ -107,10 +98,8 @@ model.inference_padding = 0
model_tf.inference_padding = 0
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
output_torch, output_tf)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf)
# save tf model
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
args.output_path)
print(' > Model conversion is successfully completed :).')
save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path)
print(" > Model conversion is successfully completed :).")


@ -10,15 +10,9 @@ from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite
parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to tflite output binary.')
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()
# Set constants


@ -8,27 +8,20 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config
sys.path.append('/home/erogol/Projects')
os.environ['CUDA_VISIBLE_DEVICES'] = ''
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to output file including file name to save TF model.')
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()
# load model config
@ -39,51 +32,48 @@ num_speakers = 0
# init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path,
map_location=torch.device('cpu'))
state_dict = checkpoint['model']
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
# init tf model
model_tf = Tacotron2(num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
model_tf = Tacotron2(
num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
)
# set initial layer mapping - these are not captured by the below heuristic approach
# TODO: set layer names so that we can remove these manual matching
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
var_map = [
('embedding/embeddings:0', 'embedding.weight'),
('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
'encoder.lstm.weight_ih_l0'),
('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0'),
('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
'encoder.lstm.weight_ih_l0_reverse'),
('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0_reverse'),
('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
('decoder/linear_projection/kernel:0',
'decoder.linear_projection.linear_layer.weight'),
('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
("embedding/embeddings:0", "embedding.weight"),
("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"),
("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"),
("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"),
("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"),
("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")),
(
"encoder/lstm/backward_lstm/lstm_cell_2/bias:0",
("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"),
),
("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"),
("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"),
("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"),
]
# %%
@ -101,10 +91,7 @@ for tf_name in tf_var_names:
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names
]
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
@ -124,25 +111,21 @@ input_ids = torch.randint(0, 24, (1, 128)).long()
o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
assert abs(o_t.detach().numpy() -
o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() -
o_tf.numpy()).sum()
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()
# compare encoder outputs
oo_en = model.encoder.inference(o_t.transpose(1, 2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
assert compare_torch_tf(oo_en, ooo_en) < 1e-5
#pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin
# compare decoder.attention_rnn
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
states[2],
training=False)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5
query = output
@ -153,8 +136,7 @@ inputs_tf = inputs.numpy()
# compare decoder.attention
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(
query, processes_inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)
attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
@ -169,13 +151,10 @@ assert compare_torch_tf(context, context_tf) < 1e-5
# compare decoder.decoder_rnn
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(
input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
states[3],
training=False)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5
@ -198,12 +177,10 @@ assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
assert compare_torch_tf(outputs_torch[2][:, 50, :],
outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4
# %%
# save tf model
save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
checkpoint['r'], args.output_path)
print(' > Model conversion is successfully completed :).')
save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path)
print(" > Model conversion is successfully completed :).")


@ -15,26 +15,19 @@ def main():
Call train.py as a new process and pass command arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument("--script", type=str, help="Target training script to distibute.")
parser.add_argument(
'--script',
type=str,
help='Target training script to distribute.')
parser.add_argument(
'--continue_path',
"--continue_path",
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
default="",
required="--config_path" not in sys.argv,
)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
)
args = parser.parse_args()
@ -44,20 +37,20 @@ def main():
# set arguments for train.py
folder_path = pathlib.Path(__file__).parent.absolute()
command = [os.path.join(folder_path, args.script)]
command.append('--continue_path={}'.format(args.continue_path))
command.append('--restore_path={}'.format(args.restore_path))
command.append('--config_path={}'.format(args.config_path))
command.append('--group_id=group_{}'.format(group_id))
command.append('')
command.append("--continue_path={}".format(args.continue_path))
command.append("--restore_path={}".format(args.restore_path))
command.append("--config_path={}".format(args.config_path))
command.append("--group_id=group_{}".format(group_id))
command.append("")
# run processes
processes = []
for i in range(num_gpus):
my_env = os.environ.copy()
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = '--rank={}'.format(i)
stdout = None if i == 0 else open(os.devnull, 'w')
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
command[-1] = "--rank={}".format(i)
stdout = None if i == 0 else open(os.devnull, "w")
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)
processes.append(p)
print(command)
@ -65,5 +58,5 @@ def main():
p.wait()
if __name__ == '__main__':
if __name__ == "__main__":
main()


@ -7,30 +7,24 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name
def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''
'''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\
'''
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""Target dataset must be defined in TTS.tts.datasets.preprocess\n\n"""
"""
Example runs:
python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
''', formatter_class=RawTextHelpFormatter)
parser.add_argument(
'--dataset',
type=str,
default='',
help='One of the target dataset names in TTS.tts.datasets.preprocess.'
)
parser.add_argument(
'--meta_file',
type=str,
default=None,
help='Path to the transcriptions file of the dataset.'
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
"--dataset", type=str, default="", help="One of the target dataset names in TTS.tts.datasets.preprocess."
)
parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.")
args = parser.parse_args()
preprocessor = get_preprocessor_by_name(args.dataset)


@ -7,15 +7,17 @@ from argparse import RawTextHelpFormatter
from multiprocessing import Pool
from tqdm import tqdm
def resample_file(func_args):
filename, output_sr = func_args
y, sr = librosa.load(filename, sr=output_sr)
librosa.output.write_wav(filename, y, sr)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='''Resample a folder recursively with librosa
description="""Resample a folder recusively with librosa
Can be used in place or create a copy of the folder as an output.\n\n
Example run:
python TTS/bin/resample.py
@ -23,46 +25,52 @@ if __name__ == '__main__':
--output_sr 22050
--output_dir /root/resampled_LJSpeech-1.1/
--n_jobs 24
''',
formatter_class=RawTextHelpFormatter)
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument('--input_dir',
type=str,
default=None,
required=True,
help='Path of the folder containing the audio files to resample')
parser.add_argument(
"--input_dir",
type=str,
default=None,
required=True,
help="Path of the folder containing the audio files to resample",
)
parser.add_argument('--output_sr',
type=int,
default=22050,
required=False,
help='Sample rate to which the audio files should be resampled')
parser.add_argument(
"--output_sr",
type=int,
default=22050,
required=False,
help="Samlple rate to which the audio files should be resampled",
)
parser.add_argument('--output_dir',
type=str,
default=None,
required=False,
help='Path of the destination folder. If not defined, the operation is done in place')
parser.add_argument(
"--output_dir",
type=str,
default=None,
required=False,
help="Path of the destination folder. If not defined, the operation is done in place",
)
parser.add_argument('--n_jobs',
type=int,
default=None,
help='Number of threads to use, by default it uses all cores')
parser.add_argument(
"--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
)
args = parser.parse_args()
if args.output_dir:
print('Recursively copying the input folder...')
print("Recursively copying the input folder...")
copy_tree(args.input_dir, args.output_dir)
args.input_dir = args.output_dir
print('Resampling the audio files...')
audio_files = glob.glob(os.path.join(args.input_dir, '**/*.wav'), recursive=True)
print(f'Found {len(audio_files)} files...')
audio_files = list(zip(audio_files, len(audio_files)*[args.output_sr]))
print("Resampling the audio files...")
audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True)
print(f"Found {len(audio_files)} files...")
audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
with Pool(processes=args.n_jobs) as p:
with tqdm(total=len(audio_files)) as pbar:
for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)):
pbar.update()
print('Done !')
print("Done !")


@ -4,6 +4,7 @@
import argparse
import sys
from argparse import RawTextHelpFormatter
# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path
@ -14,22 +15,20 @@ from TTS.utils.synthesizer import Synthesizer
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
if v.lower() in ('no', 'false', 'f', 'n', '0'):
if v.lower() in ("no", "false", "f", "n", "0"):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError("Boolean value expected.")
def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''
'''You can either use your trained model or choose a model from the provided list.\n\n'''\
'''If you don't specify any models, then it uses LJSpeech based English models\n\n'''\
'''
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Synthesize speech on command line.\n\n"""
"""You can either use your trained model or choose a model from the provided list.\n\n"""
"""If you don't specify any models, then it uses LJSpeech based English models\n\n"""
"""
Example runs:
# list provided models
@ -51,106 +50,80 @@ def main():
./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/model.pth.tar --config_path path/to/config.json --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json
''',
formatter_class=RawTextHelpFormatter)
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument(
'--list_models',
"--list_models",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=False,
help='list available pre-trained tts and vocoder models.'
)
parser.add_argument(
'--text',
type=str,
default=None,
help='Text to generate speech.'
)
help="list available pre-trained tts and vocoder models.",
)
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")
# Args for running pre-trained TTS models.
parser.add_argument(
'--model_name',
"--model_name",
type=str,
default="tts_models/en/ljspeech/speedy-speech-wn",
help=
'Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>'
help="Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
)
parser.add_argument(
'--vocoder_name',
"--vocoder_name",
type=str,
default=None,
help=
'Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>'
help="Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>",
)
# Args for running custom models
parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
parser.add_argument(
'--config_path',
default=None,
type=str,
help='Path to model config file.'
)
parser.add_argument(
'--model_path',
"--model_path",
type=str,
default=None,
help='Path to model file.',
help="Path to model file.",
)
parser.add_argument(
'--out_path',
"--out_path",
type=str,
default='tts_output.wav',
help='Output wav file path.',
default="tts_output.wav",
help="Output wav file path.",
)
parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
parser.add_argument(
'--use_cuda',
type=bool,
help='Run model on CUDA.',
default=False
)
parser.add_argument(
'--vocoder_path',
"--vocoder_path",
type=str,
help=
'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
default=None,
)
parser.add_argument(
'--vocoder_config_path',
type=str,
help='Path to vocoder model config file.',
default=None)
parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)
# args for multi-speaker synthesis
parser.add_argument("--speakers_json", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument(
'--speakers_json',
type=str,
help="JSON file for multi-speaker model.",
default=None)
parser.add_argument(
'--speaker_idx',
"--speaker_idx",
type=str,
help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.",
default=None)
parser.add_argument(
'--gst_style',
help="Wav path file for GST stylereference.",
default=None)
default=None,
)
parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None)
# aux args
parser.add_argument(
'--save_spectogram',
"--save_spectogram",
type=bool,
help="If true save raw spectogram for further (vocoder) processing in out_path.",
default=False)
default=False,
)
args = parser.parse_args()
# print the description if either text or list_models is not set
if args.text is None and not args.list_models:
parser.parse_args(['-h'])
parser.parse_args(["-h"])
# load model manager
path = Path(__file__).parent / "../.models.json"
@ -169,7 +142,7 @@ def main():
# CASE2: load pre-trained models
if args.model_name is not None:
model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
if args.vocoder_name is not None:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)


@ -25,12 +25,11 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env
if __name__ == '__main__':
if __name__ == "__main__":
use_cuda, num_gpus = setup_torch_training_env(True, False)
# torch.autograd.set_detect_anomaly(True)
@ -44,10 +43,9 @@ if __name__ == '__main__':
compute_linear_spec=False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
@ -56,8 +54,10 @@ if __name__ == '__main__':
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=not is_val,
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding
and c.use_external_speaker_embedding_file else None)
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
@ -72,9 +72,9 @@ if __name__ == '__main__':
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
def format_data(data):
@ -94,10 +94,7 @@ if __name__ == '__main__':
speaker_c = data[8]
else:
# return speaker_id to be used by an embedding layer
speaker_c = [
speaker_mapping[speaker_name]
for speaker_name in speaker_names
]
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_c = torch.LongTensor(speaker_c)
else:
speaker_c = None
@ -109,18 +106,15 @@ if __name__ == '__main__':
mel_lengths = mel_lengths.cuda(non_blocking=True)
if speaker_c is not None:
speaker_c = speaker_c.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, item_idx
return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx
def train(data_loader, model, criterion, optimizer, scheduler, ap,
global_step, epoch, training_phase):
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase):
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -130,8 +124,16 @@ if __name__ == '__main__':
start_time = time.time()
# format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, _ = format_data(data)
(
text_input,
text_lengths,
mel_targets,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
_,
) = format_data(data)
loader_time = time.time() - end_time
@ -141,36 +143,32 @@ if __name__ == '__main__':
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
text_input,
text_lengths,
mel_targets,
mel_lengths,
g=speaker_c,
phase=training_phase)
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
)
# compute loss
loss_dict = criterion(logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase)
loss_dict = criterion(
logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase,
)
# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward()
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), c.grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), c.grad_clip)
loss_dict["loss"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()
# setup lr
@ -178,25 +176,21 @@ if __name__ == '__main__':
scheduler.step()
# current_lr
current_lr = optimizer.param_groups[0]['lr']
current_lr = optimizer.param_groups[0]["lr"]
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
step_time = time.time() - start_time
epoch_time += step_time
# aggregate losses from processes
if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data,
num_gpus)
loss_dict['loss_ssim'] = reduce_tensor(
loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(
loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
num_gpus)
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -210,48 +204,43 @@ if __name__ == '__main__':
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length,
1], # value, precision
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict,
keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict['loss'])
save_checkpoint(
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict["loss"],
)
# wait all kernels to be completed
torch.cuda.synchronize()
@ -259,8 +248,7 @@ if __name__ == '__main__':
# Diagnostic visualizations
if decoder_output is not None:
idx = np.random.randint(mel_targets.shape[0])
pred_spec = decoder_output[idx].detach().data.cpu(
).numpy().T
pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
gt_spec = mel_targets[idx].data.cpu().numpy().T
align_img = alignments[idx].data.cpu()
@ -274,14 +262,11 @@ if __name__ == '__main__':
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time,
keep_avg)
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)
# Plot Epoch Stats
if args.rank == 0:
@ -293,8 +278,7 @@ if __name__ == '__main__':
return keep_avg.avg_values, global_step
@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch,
training_phase):
def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase):
model.eval()
epoch_time = 0
keep_avg = KeepAverage()
@ -304,50 +288,41 @@ if __name__ == '__main__':
start_time = time.time()
# format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
_, _, _ = format_data(data)
text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data)
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
text_input,
text_lengths,
mel_targets,
mel_lengths,
g=speaker_c,
phase=training_phase)
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
)
# compute loss
loss_dict = criterion(logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase)
loss_dict = criterion(
logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase,
)
# step time
step_time = time.time() - start_time
epoch_time += step_time
# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments,
binary=True)
loss_dict['align_error'] = align_error
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict["align_error"] = align_error
# aggregate losses from processes
if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(
loss_dict['loss_l1'].data, num_gpus)
loss_dict['loss_ssim'] = reduce_tensor(
loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(
loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
num_gpus)
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -361,12 +336,11 @@ if __name__ == '__main__':
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)
if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict,
keep_avg.avg_values)
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Diagnostic visualizations
@ -376,19 +350,14 @@ if __name__ == '__main__':
align_img = alignments[idx].data.cpu()
eval_figures = {
"prediction": plot_spectrogram(pred_spec,
ap,
output_fig=False),
"ground_truth": plot_spectrogram(gt_spec,
ap,
output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False)
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
}
# Sample audio
eval_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@ -401,7 +370,7 @@ if __name__ == '__main__':
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:
@ -413,9 +382,9 @@ if __name__ == '__main__':
print(" | > Synthesizing test sentences")
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(
speaker_mapping.keys())[randrange(
len(speaker_mapping) - 1)]]['embedding']
speaker_embedding = speaker_mapping[
list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]
]["embedding"]
speaker_id = None
else:
speaker_id = 0
@ -437,25 +406,22 @@ if __name__ == '__main__':
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(
idx)] = plot_spectrogram(postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
@ -464,69 +430,55 @@ if __name__ == '__main__':
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)
# DISTRUBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)
# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets,
eval_split=True)
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
# set the portion of the data used for training if set in config.json
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(
len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(
len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(
c, args, meta_data_train, OUT_PATH)
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
# setup model
model = setup_model(num_chars,
num_speakers,
c,
speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(),
lr=c.lr,
weight_decay=0,
betas=(0.9, 0.98),
eps=1e-9)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = AlignTTSLoss(c)
if args.restore_path:
print(
f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except
model.load_state_dict(checkpoint["model"])
except: # pylint: disable=bare-except
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["initial_lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
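When full checkpoint loading fails above, the script falls back to a partial initialization via set_init_dict. A rough stand-in for what such a helper typically does, assuming it keeps only tensors whose names and shapes match the current model (an illustrative sketch, not the repo's exact implementation):

import torch

def partial_restore(model, checkpoint_state):
    model_state = model.state_dict()
    # keep only checkpoint tensors that exist in the model with the same shape
    compatible = {
        k: v for k, v in checkpoint_state.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    model_state.update(compatible)
    model.load_state_dict(model_state)
    print(f" > Restored {len(compatible)}/{len(model_state)} tensors.")

# usage: partial_restore(model, torch.load("checkpoint.pth.tar", map_location="cpu")["model"])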
@ -539,9 +491,7 @@ if __name__ == '__main__':
model = DDP_th(model, device_ids=[args.rank])
if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
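NoamLR above applies warmup followed by inverse-square-root decay. A minimal sketch of a Noam-style schedule built on torch's LambdaLR, using the common Transformer formulation; the repo's NoamLR may differ in details such as the scaling constant:

import torch

def noam_lambda(warmup_steps):
    def factor(step):
        step = max(step, 1)
        # ramps up linearly for warmup_steps, then decays as 1/sqrt(step)
        return warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)
    return factor

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=noam_lambda(4000))
# call scheduler.step() once per training step: lr ramps up for 4000 steps, then decays.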
@ -549,16 +499,14 @@ if __name__ == '__main__':
print("\n > Model has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
# define dataloaders
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
@ -573,9 +521,9 @@ if __name__ == '__main__':
if not True in vals:
phase = 0
else:
phase = len(c.phase_start_steps) - [
i < global_step for i in c.phase_start_steps
][::-1].index(True) - 1
phase = (
len(c.phase_start_steps) - [i < global_step for i in c.phase_start_steps][::-1].index(True) - 1
)
else:
phase = None
return phase
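The hunk above computes which AlignTTS training phase is active from the configured phase start steps. A minimal, self-contained sketch of that selection logic (function and variable names here are illustrative, not the script's exact code):

def current_phase(global_step, phase_start_steps):
    """Return the index of the last phase whose start step has already passed."""
    if phase_start_steps is None:
        return None
    vals = [step < global_step for step in phase_start_steps]
    if True not in vals:
        return 0
    # index of the last reached start step, counted from the end
    return len(phase_start_steps) - vals[::-1].index(True) - 1

# e.g. current_phase(90000, [0, 40000, 80000, 160000, 170000]) -> 2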
@ -584,32 +532,30 @@ if __name__ == '__main__':
cur_phase = set_phase()
print(f"\n > Current AlignTTS phase: {cur_phase}")
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model,
criterion, optimizer,
scheduler, ap,
global_step, epoch,
cur_phase)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch, cur_phase)
train_avg_loss_dict, global_step = train(
train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase
)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
target_loss = train_avg_loss_dict["avg_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after)
target_loss = eval_avg_loss_dict["avg_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
)
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
try:
main(args)


@ -12,14 +12,17 @@ from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.utils.generic_utils import \
check_config_speaker_encoder, save_best_model
from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import (
count_parameters,
create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict,
)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
@ -34,28 +37,30 @@ print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)
def setup_loader(ap: AudioProcessor,
is_val: bool = False,
verbose: bool = False):
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
if is_val:
loader = None
else:
dataset = MyDataset(ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=False,
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
additive_noise=c.storage["additive_noise"],
verbose=verbose)
dataset = MyDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=False,
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
additive_noise=c.storage["additive_noise"],
verbose=verbose,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn)
loader = DataLoader(
dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn,
)
return loader
@ -63,7 +68,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
data_loader = setup_loader(ap, is_val=False, verbose=True)
model.train()
epoch_time = 0
best_loss = float('inf')
best_loss = float("inf")
avg_loss = 0
avg_loader_time = 0
end_time = time.time()
@ -89,9 +94,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
outputs = model(inputs)
# loss computation
loss = criterion(
outputs.view(c.num_speakers_in_batch,
outputs.shape[0] // c.num_speakers_in_batch, -1))
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()
@ -100,11 +103,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
epoch_time += step_time
# Averaged Loss and Averaged Loader Time
avg_loss = 0.01 * loss.item() \
+ 0.99 * avg_loss if avg_loss != 0 else loss.item()
avg_loader_time = 1/c.num_loader_workers * loader_time + \
(c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time
current_lr = optimizer.param_groups[0]['lr']
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
avg_loader_time = (
1 / c.num_loader_workers * loader_time + (c.num_loader_workers - 1) / c.num_loader_workers * avg_loader_time
if avg_loader_time != 0
else loader_time
)
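The two assignments above keep exponentially weighted running averages of the loss and the loader time. A hedged, generic sketch of that update rule (the helper name and the 0.01 weight are illustrative; the loader-time average in the script derives its weight from num_loader_workers instead):

def running_average(avg, value, weight=0.01):
    """Exponential moving average, seeded with the first observed value."""
    return value if avg == 0 else weight * value + (1.0 - weight) * avg

# usage: avg_loss = running_average(avg_loss, loss.item())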
current_lr = optimizer.param_groups[0]["lr"]
if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats
@ -113,13 +118,12 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
"avg_loader_time": avg_loader_time
"avg_loader_time": avg_loader_time,
}
tb_logger.tb_train_epoch_stats(global_step, train_stats)
figures = {
# FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
10),
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
}
tb_logger.tb_train_figures(global_step, figures)
@ -127,13 +131,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time,
loader_time, avg_loader_time, current_lr),
flush=True)
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
)
# save best model
best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
OUT_PATH, global_step)
best_loss = save_best_model(model, optimizer, avg_loss, best_loss, OUT_PATH, global_step)
end_time = time.time()
return avg_loss, global_step
@ -145,14 +149,16 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_eval
ap = AudioProcessor(**c.audio)
model = SpeakerEncoder(input_dim=c.model['input_dim'],
proj_dim=c.model['proj_dim'],
lstm_dim=c.model['lstm_dim'],
num_lstm_layers=c.model['num_lstm_layers'])
model = SpeakerEncoder(
input_dim=c.model["input_dim"],
proj_dim=c.model["proj_dim"],
lstm_dim=c.model["lstm_dim"],
num_lstm_layers=c.model["num_lstm_layers"],
)
optimizer = RAdam(model.parameters(), lr=c.lr)
if c.loss == "ge2e":
criterion = GE2ELoss(loss_method='softmax')
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
else:
@ -166,7 +172,7 @@ def main(args): # pylint: disable=redefined-outer-name
# optimizer.load_state_dict(checkpoint['optimizer'])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
model.load_state_dict(checkpoint["model"])
except KeyError:
print(" > Partial model initialization.")
model_dict = model.state_dict()
@ -174,10 +180,9 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -186,9 +191,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()
if c.lr_decay:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
@ -199,55 +202,39 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
global_step = args.restore_step
_, global_step = train(model, criterion, optimizer, scheduler, ap,
global_step)
_, global_step = train(model, criterion, optimizer, scheduler, ap, global_step)
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--restore_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).',
default=0)
"--restore_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).", default=0
)
parser.add_argument(
'--config_path',
"--config_path",
type=str,
required=True,
help='Path to config file for training.',
help="Path to config file for training.",
)
parser.add_argument('--debug',
type=bool,
default=True,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path',
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')
parser.add_argument("--debug", type=bool, default=True, help="Do not verify commit integrity to run training.")
parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
parser.add_argument("--output_path", type=str, help="path for training outputs.", default="")
parser.add_argument("--output_folder", type=str, default="", help="folder name for training outputs.")
args = parser.parse_args()
# setup output paths and read configs
c = load_config(args.config_path)
check_config_speaker_encoder(c)
_ = os.path.dirname(os.path.realpath(__file__))
if args.data_path != '':
if args.data_path != "":
c.data_path = args.data_path
if args.output_path == '':
if args.output_path == "":
OUT_PATH = os.path.join(_, c.output_path)
else:
OUT_PATH = args.output_path
if args.output_folder == '':
if args.output_folder == "":
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
else:
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
@ -259,7 +246,7 @@ if __name__ == '__main__':
copy_model_files(c, args.config_path, OUT_PATH, new_fields)
LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder')
tb_logger = TensorboardLogger(LOG_DIR, model_name="Speaker_Encoder")
try:
main(args)


@ -8,6 +8,7 @@ import traceback
from random import randrange
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env
@ -44,19 +44,21 @@ def setup_loader(ap, r, is_val=False, verbose=False):
compute_linear_spec=False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=c['use_noise_augment'] and not is_val,
use_noise_augment=c["use_noise_augment"] and not is_val,
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
@ -95,9 +97,7 @@ def format_data(data):
speaker_c = data[8]
else:
# return speaker_id to be used by an embedding layer
speaker_c = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_c = torch.LongTensor(speaker_c)
else:
speaker_c = None
@ -112,13 +112,22 @@ def format_data(data):
speaker_c = speaker_c.cuda(non_blocking=True)
if attn_mask is not None:
attn_mask = attn_mask.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, item_idx
return (
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
attn_mask,
item_idx,
)
def data_depended_init(data_loader, model):
"""Data depended initialization for activation normalization."""
if hasattr(model, 'module'):
if hasattr(model, "module"):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(True)
@ -134,17 +143,15 @@ def data_depended_init(data_loader, model):
for _, data in enumerate(data_loader):
# format data
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
_, _, attn_mask, _ = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed, _, _, attn_mask, _ = format_data(data)
# forward pass model
_ = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
_ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
if num_iter == c.data_dep_init_iter:
break
num_iter += 1
if hasattr(model, 'module'):
if hasattr(model, "module"):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(False)
@ -155,15 +162,13 @@ def data_depended_init(data_loader, model):
return model
def train(data_loader, model, criterion, optimizer, scheduler,
ap, global_step, epoch):
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -173,8 +178,17 @@ def train(data_loader, model, criterion, optimizer, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)
(
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
attn_mask,
_,
) = format_data(data)
loader_time = time.time() - end_time
@ -184,24 +198,22 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)
# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward()
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
loss_dict["loss"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()
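The branch above is the standard torch.cuda.amp pattern: scale the loss for backward, unscale before gradient clipping, then step and update the scaler. A self-contained sketch with toy placeholders for the model, optimizer, and data (requires a CUDA device):

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=True)
grad_clip = 5.0

x, y = torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()
optimizer.zero_grad()
with torch.cuda.amp.autocast(enabled=True):
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()                                  # scaled backward pass
scaler.unscale_(optimizer)                                     # unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)                                         # skipped if grads overflowed
scaler.update()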
# setup lr
@ -209,20 +221,20 @@ def train(data_loader, model, criterion, optimizer, scheduler,
scheduler.step()
# current_lr
current_lr = optimizer.param_groups[0]['lr']
current_lr = optimizer.param_groups[0]["lr"]
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
step_time = time.time() - start_time
epoch_time += step_time
# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -236,9 +248,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training progress
@ -250,26 +262,29 @@ def train(data_loader, model, criterion, optimizer, scheduler,
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
save_checkpoint(
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict["loss"],
)
# wait all kernels to be completed
torch.cuda.synchronize()
@ -278,7 +293,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1]
if hasattr(model, 'module'):
if hasattr(model, "module"):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
@ -299,9 +314,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# Sample audio
train_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -328,16 +341,15 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
_, _, attn_mask, _ = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data)
# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
)
# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)
# step time
step_time = time.time() - start_time
@ -345,13 +357,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -365,7 +377,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)
if c.print_eval:
@ -375,7 +387,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# Diagnostic visualizations
# direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1]
if hasattr(model, 'module'):
if hasattr(model, "module"):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)
@ -389,13 +401,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
eval_figures = {
"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)
"alignment": plot_alignment(align_img),
}
# Sample audio
eval_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@ -408,7 +419,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:
@ -420,7 +431,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
print(" | > Synthesizing test sentences")
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
"embedding"
]
speaker_id = None
else:
speaker_id = 0
@ -442,25 +455,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
@ -470,13 +480,12 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)
# DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
@ -486,10 +495,10 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
# set the portion of the data used for training
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
@ -501,26 +510,25 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except
model.load_state_dict(checkpoint["model"])
except: # pylint: disable=bare-except
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
args.restore_step = checkpoint['step']
group["initial_lr"] = c.lr
print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -533,9 +541,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = DDP_th(model, device_ids=[args.rank])
if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
@ -543,16 +549,14 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Model has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
# define dataloaders
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
@ -562,25 +566,32 @@ def main(args): # pylint: disable=redefined-outer-name
model = data_depended_init(train_loader, model)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model,
criterion, optimizer,
scheduler, ap, global_step,
epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch)
train_avg_loss_dict, global_step = train(
train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
target_loss = train_avg_loss_dict["avg_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
global_step, epoch, c.r, OUT_PATH, model_characters,
keep_all_best=keep_all_best, keep_after=keep_after)
target_loss = eval_avg_loss_dict["avg_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
c.r,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
try:
main(args)


@ -10,6 +10,7 @@ from random import randrange
import torch
from TTS.utils.arguments import parse_arguments, process_args
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env
@ -44,10 +44,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
compute_linear_spec=False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
@ -56,7 +55,10 @@ def setup_loader(ap, r, is_val=False, verbose=False):
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=not is_val,
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
)
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
@ -95,9 +97,7 @@ def format_data(data):
speaker_c = data[8]
else:
# return speaker_id to be used by an embedding layer
speaker_c = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_c = torch.LongTensor(speaker_c)
else:
speaker_c = None
@ -105,7 +105,7 @@ def format_data(data):
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
@ -115,8 +115,10 @@ def format_data(data):
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, :text_lengths[idx]] = dur
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
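The loop above converts a hard attention mask into per-character durations and trims surplus frames so the durations sum to the spectrogram length. A single-item sketch of the same idea, with assumed tensor shapes (not the script's exact code):

import torch

def durations_from_attention(attn, text_len, mel_len):
    """attn: [1, max_text_len, max_mel_len] binary alignment for one utterance."""
    c_idxs = attn[:, :text_len, :mel_len].max(1)[1]       # most-attended char per frame
    c_idxs, counts = torch.unique(c_idxs, return_counts=True)
    dur = torch.ones([text_len], dtype=counts.dtype)
    dur[c_idxs] = counts
    extra = int(dur.sum() - mel_len)                      # unattended chars got 1 frame each
    if extra > 0:
        dur[torch.argsort(-dur)[:extra]] -= 1             # trim the longest durations
    assert dur.sum() == mel_len
    return dur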
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda(non_blocking=True)
@ -127,19 +129,27 @@ def format_data(data):
speaker_c = speaker_c.cuda(non_blocking=True)
attn_mask = attn_mask.cuda(non_blocking=True)
durations = durations.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, durations, item_idx
return (
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
attn_mask,
durations,
item_idx,
)
def train(data_loader, model, criterion, optimizer, scheduler,
ap, global_step, epoch):
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -149,8 +159,18 @@ def train(data_loader, model, criterion, optimizer, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data)
(
text_input,
text_lengths,
mel_targets,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
_,
dur_target,
_,
) = format_data(data)
loader_time = time.time() - end_time
@ -160,23 +180,24 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, alignments = model.forward(
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
)
# compute loss
loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
loss_dict = criterion(
decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
)
# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward()
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
loss_dict["loss"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()
# setup lr
@ -184,21 +205,21 @@ def train(data_loader, model, criterion, optimizer, scheduler,
scheduler.step()
# current_lr
current_lr = optimizer.param_groups[0]['lr']
current_lr = optimizer.param_groups[0]["lr"]
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
step_time = time.time() - start_time
epoch_time += step_time
# aggregate losses from processes
if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -212,41 +233,43 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
save_checkpoint(
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict["loss"],
)
# wait all kernels to be completed
torch.cuda.synchronize()
@ -267,9 +290,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -296,16 +317,18 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
_, _, _, dur_target, _ = format_data(data)
text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data)
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, alignments = model.forward(
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
)
# compute loss
loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
loss_dict = criterion(
decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
)
# step time
step_time = time.time() - start_time
@ -313,14 +336,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
# aggregate losses from processes
if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -334,7 +357,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)
if c.print_eval:
@ -350,13 +373,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
eval_figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False)
"alignment": plot_alignment(align_img, output_fig=False),
}
# Sample audio
eval_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@ -369,7 +391,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:
@ -381,7 +403,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
print(" | > Synthesizing test sentences")
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
"embedding"
]
speaker_id = None
else:
speaker_id = 0
@ -403,25 +427,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
@ -432,13 +453,12 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)
# DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
# set model characters
model_characters = phonemes if c.use_phonemes else symbols
@ -448,10 +468,10 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)
# set the portion of the data used for training if set in config.json
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
@ -463,26 +483,25 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except
model.load_state_dict(checkpoint["model"])
except: # pylint: disable=bare-except
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["initial_lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -495,9 +514,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = DDP_th(model, device_ids=[args.rank])
if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
@ -505,16 +522,14 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Model has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
# define dataloaders
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
@ -523,24 +538,32 @@ def main(args): # pylint: disable=redefined-outer-name
global_step = args.restore_step
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
scheduler, ap, global_step,
epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch)
train_avg_loss_dict, global_step = train(
train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
target_loss = train_avg_loss_dict["avg_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
global_step, epoch, c.r, OUT_PATH, model_characters,
keep_all_best=keep_all_best, keep_after=keep_after)
target_loss = eval_avg_loss_dict["avg_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
c.r,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
try:
main(args)


@ -22,14 +22,17 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
gradual_training_scheduler, set_weight_decay,
setup_torch_training_env)
from TTS.utils.training import (
NoamLR,
adam_weight_decay,
check_update,
gradual_training_scheduler,
set_weight_decay,
setup_torch_training_env,
)
use_cuda, num_gpus = setup_torch_training_env(True, False)
@ -42,13 +45,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
dataset = MyDataset(
r,
c.text_cleaner,
compute_linear_spec=c.model.lower() == 'tacotron',
compute_linear_spec=c.model.lower() == "tacotron",
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
@ -56,11 +58,10 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
verbose=verbose,
speaker_mapping=(speaker_mapping if (
c.use_speaker_embedding
and c.use_external_speaker_embedding_file
) else None)
)
speaker_mapping=(
speaker_mapping if (c.use_speaker_embedding and c.use_external_speaker_embedding_file) else None
),
)
if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.
@ -75,11 +76,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
def format_data(data):
# setup input data
text_input = data[0]
@ -97,21 +99,16 @@ def format_data(data):
speaker_embeddings = data[8]
speaker_ids = None
else:
speaker_ids = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
speaker_embeddings = None
else:
speaker_embeddings = None
speaker_ids = None
# set stop targets view, we predict a single stop token per iteration.
stop_targets = stop_targets.view(text_input.shape[0],
stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) >
0.0).unsqueeze(2).float().squeeze(2)
stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
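# A minimal standalone sketch of what the two reshaping lines above do (toy values,
# not taken from this training script): with reduction factor r the decoder emits r
# frames per step, so the per-frame stop flags are grouped into steps and collapsed
# to one flag per step.
import torch

r = 3  # assumed reduction factor, for illustration only
toy_stop_targets = torch.tensor([[0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
                                 [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]])
B, T = toy_stop_targets.shape
toy_stop_targets = toy_stop_targets.view(B, T // r, -1)      # (B, T) -> (B, T//r, r)
toy_stop_targets = (toy_stop_targets.sum(2) > 0.0).float()   # any flagged frame marks the step
print(toy_stop_targets)                                      # tensor([[0., 1.], [0., 1.]])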
# dispatch data to GPU
if use_cuda:
@ -126,17 +123,26 @@ def format_data(data):
if speaker_embeddings is not None:
speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
return (
text_input,
text_lengths,
mel_input,
mel_lengths,
linear_input,
stop_targets,
speaker_ids,
speaker_embeddings,
max_text_length,
max_spec_length,
)
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
ap, global_step, epoch, scaler, scaler_st):
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st):
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -145,7 +151,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
(
text_input,
text_lengths,
mel_input,
mel_lengths,
linear_input,
stop_targets,
speaker_ids,
speaker_embeddings,
max_text_length,
max_spec_length,
) = format_data(data)
loader_time = time.time() - end_time
global_step += 1
@ -161,35 +178,65 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
(
decoder_output,
postnet_output,
alignments,
stop_tokens,
decoder_backward_output,
alignments_backward,
) = model(
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_ids=speaker_ids,
speaker_embeddings=speaker_embeddings,
)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_ids=speaker_ids,
speaker_embeddings=speaker_embeddings,
)
decoder_backward_output = None
alignments_backward = None
# set the [alignment] lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
alignment_lengths = (
mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
) // model.decoder.r
else:
alignment_lengths = mel_lengths // model.decoder.r
alignment_lengths = mel_lengths // model.decoder.r
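# A small numeric sketch of the branch above (numbers invented): the batch is padded so
# the longest mel spectrogram becomes a multiple of r, and every length is shifted by
# that same pad before dividing by r, giving the number of decoder steps each sequence
# spans for the guided-attention loss.
import torch

r = 5
toy_mel_lengths = torch.tensor([23, 38, 37])
if toy_mel_lengths.max() % r != 0:
    pad = r - (toy_mel_lengths.max() % r)            # pad needed by the longest sequence
    alignment_steps = (toy_mel_lengths + pad) // r
else:
    alignment_steps = toy_mel_lengths // r
print(alignment_steps)                               # tensor([5, 8, 7])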
# compute loss
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths,
alignments_backward, text_lengths)
loss_dict = criterion(
postnet_output,
decoder_output,
mel_input,
linear_input,
stop_tokens,
stop_targets,
mel_lengths,
decoder_backward_output,
alignments,
alignment_lengths,
alignments_backward,
text_lengths,
)
# check nan loss
if torch.isnan(loss_dict['loss']).any():
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
if torch.isnan(loss_dict["loss"]).any():
raise RuntimeError(f"Detected NaN loss at step {global_step}.")
# optimizer step
if c.mixed_precision:
# model optimizer step in mixed precision mode
scaler.scale(loss_dict['loss']).backward()
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
@ -198,7 +245,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
# stopnet optimizer step
if c.separate_stopnet:
scaler_st.scale(loss_dict['stopnet_loss']).backward()
scaler_st.scale(loss_dict["stopnet_loss"]).backward()
scaler.unscale_(optimizer_st)
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
@ -208,14 +255,14 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
grad_norm_st = 0
else:
# main model optimizer step
loss_dict['loss'].backward()
loss_dict["loss"].backward()
optimizer, current_lr = adam_weight_decay(optimizer)
grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
optimizer.step()
# stopnet optimizer step
if c.separate_stopnet:
loss_dict['stopnet_loss'].backward()
loss_dict["stopnet_loss"].backward()
optimizer_st, _ = adam_weight_decay(optimizer_st)
grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
optimizer_st.step()
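# Both branches above follow the standard torch.cuda.amp recipe: scale the loss before
# backward, unscale before clipping, and let the scaler drive optimizer.step(). A
# stripped-down sketch of that pattern; the script's adam_weight_decay / check_update
# helpers are replaced here by plain gradient clipping, and model, data and
# hyperparameters are placeholders.
import torch

toy_model = torch.nn.Linear(10, 1)
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
use_amp = torch.cuda.is_available()
amp_scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
x, y = torch.randn(8, 10), torch.randn(8, 1)
if use_amp:
    toy_model, x, y = toy_model.cuda(), x.cuda(), y.cuda()

with torch.cuda.amp.autocast(enabled=use_amp):
    toy_loss = torch.nn.functional.mse_loss(toy_model(x), y)
toy_opt.zero_grad()
if use_amp:
    amp_scaler.scale(toy_loss).backward()        # backward on the scaled loss
    amp_scaler.unscale_(toy_opt)                 # so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0)
    amp_scaler.step(toy_opt)                     # skipped internally on inf/NaN gradients
    amp_scaler.update()
else:
    toy_loss.backward()
    torch.nn.utils.clip_grad_norm_(toy_model.parameters(), 1.0)
    toy_opt.step()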
@ -224,17 +271,19 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
step_time = time.time() - start_time
epoch_time += step_time
# aggregate losses from processes
if num_gpus > 1:
loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']
loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
loss_dict["stopnet_loss"] = (
reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else loss_dict["stopnet_loss"]
)
# detach loss values
loss_dict_new = dict()
@ -248,9 +297,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training progress
@ -262,8 +311,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# Plot Training Iter Stats
@ -273,7 +321,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
"lr": current_lr,
"grad_norm": grad_norm,
"grad_norm_st": grad_norm_st,
"step_time": step_time
"step_time": step_time,
}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
@ -281,17 +329,26 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict['postnet_loss'],
characters=model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None)
save_checkpoint(
model,
optimizer,
global_step,
epoch,
model.decoder.r,
OUT_PATH,
optimizer_st=optimizer_st,
model_loss=loss_dict["postnet_loss"],
characters=model_characters,
scaler=scaler.state_dict() if c.mixed_precision else None,
)
# Diagnostic visualizations
const_spec = postnet_output[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
"Tacotron", "TacotronGST"
] else mel_input[0].data.cpu().numpy()
gt_spec = (
linear_input[0].data.cpu().numpy()
if c.model in ["Tacotron", "TacotronGST"]
else mel_input[0].data.cpu().numpy()
)
align_img = alignments[0].data.cpu().numpy()
figures = {
@ -301,7 +358,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
}
if c.bidirectional_decoder or c.double_decoder_consistency:
figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
figures["alignment_backward"] = plot_alignment(
alignments_backward[0].data.cpu().numpy(), output_fig=False
)
tb_logger.tb_train_figures(global_step, figures)
@ -310,9 +369,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
train_audio = ap.inv_spectrogram(const_spec.T)
else:
train_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -339,31 +396,62 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
start_time = time.time()
# format data
text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
(
text_input,
text_lengths,
mel_input,
mel_lengths,
linear_input,
stop_targets,
speaker_ids,
speaker_embeddings,
_,
_,
) = format_data(data)
assert mel_input.shape[1] % model.decoder.r == 0
# forward pass model
if c.bidirectional_decoder or c.double_decoder_consistency:
decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
(
decoder_output,
postnet_output,
alignments,
stop_tokens,
decoder_backward_output,
alignments_backward,
) = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
)
else:
decoder_output, postnet_output, alignments, stop_tokens = model(
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
)
decoder_backward_output = None
alignments_backward = None
# set the alignment lengths wrt reduction factor for guided attention
if mel_lengths.max() % model.decoder.r != 0:
alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
alignment_lengths = (
mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
) // model.decoder.r
else:
alignment_lengths = mel_lengths // model.decoder.r
alignment_lengths = mel_lengths // model.decoder.r
# compute loss
loss_dict = criterion(postnet_output, decoder_output, mel_input,
linear_input, stop_tokens, stop_targets,
mel_lengths, decoder_backward_output,
alignments, alignment_lengths, alignments_backward,
text_lengths)
loss_dict = criterion(
postnet_output,
decoder_output,
mel_input,
linear_input,
stop_tokens,
stop_targets,
mel_lengths,
decoder_backward_output,
alignments,
alignment_lengths,
alignments_backward,
text_lengths,
)
# step time
step_time = time.time() - start_time
@ -371,14 +459,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error
# aggregate losses from processes
if num_gpus > 1:
loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
if c.stopnet:
loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus)
# detach loss values
loss_dict_new = dict()
@ -392,7 +480,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)
if c.print_eval:
@ -402,15 +490,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])
const_spec = postnet_output[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
"Tacotron", "TacotronGST"
] else mel_input[idx].data.cpu().numpy()
gt_spec = (
linear_input[idx].data.cpu().numpy()
if c.model in ["Tacotron", "TacotronGST"]
else mel_input[idx].data.cpu().numpy()
)
align_img = alignments[idx].data.cpu().numpy()
eval_figures = {
"prediction": plot_spectrogram(const_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False)
"alignment": plot_alignment(align_img, output_fig=False),
}
# Sample audio
@ -418,14 +508,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
eval_audio = ap.inv_spectrogram(const_spec.T)
else:
eval_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])
# Plot Validation Stats
if c.bidirectional_decoder or c.double_decoder_consistency:
align_b_img = alignments_backward[idx].data.cpu().numpy()
eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False)
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
tb_logger.tb_eval_figures(global_step, eval_figures)
@ -436,7 +525,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:
@ -447,13 +536,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.use_speaker_embedding else None
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None
speaker_embedding = (
speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"]
if c.use_external_speaker_embedding_file and c.use_speaker_embedding
else None
)
style_wav = c.get("gst_style_input")
if style_wav is None and c.use_gst:
            # initialize GST with a zero dict.
style_wav = {}
print("WARNING: You don't provided a gst style wav, for this reason we use a zero tensor!")
for i in range(c.gst['gst_style_tokens']):
for i in range(c.gst["gst_style_tokens"]):
style_wav[str(i)] = 0
style_wav = c.get("gst_style_input")

for idx, test_sentence in enumerate(test_sentences):
@ -468,25 +561,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)
file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
postnet_output, ap, output_fig=False)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment, output_fig=False)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
@ -498,13 +588,12 @@ def main(args): # pylint: disable=redefined-outer-name
ap = AudioProcessor(**c.audio)
# setup custom characters if set in config file.
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)
    # DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model_characters = phonemes if c.use_phonemes else symbols
@ -512,10 +601,10 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)
# set the portion of the data used for training
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]
# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
@ -529,9 +618,7 @@ def main(args): # pylint: disable=redefined-outer-name
params = set_weight_decay(model, c.wd)
optimizer = RAdam(params, lr=c.lr, weight_decay=0)
if c.stopnet and c.separate_stopnet:
optimizer_st = RAdam(model.decoder.stopnet.parameters(),
lr=c.lr,
weight_decay=0)
optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
else:
optimizer_st = None
@ -539,13 +626,13 @@ def main(args): # pylint: disable=redefined-outer-name
criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint['model'])
model.load_state_dict(checkpoint["model"])
# optimizer restore
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if "scaler" in checkpoint and c.mixed_precision:
print(" > Restoring AMP Scaler...")
scaler.load_state_dict(checkpoint["scaler"])
@ -554,17 +641,16 @@ def main(args): # pylint: disable=redefined-outer-name
except (KeyError, RuntimeError):
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
# torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
# print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -577,9 +663,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = apply_gradient_allreduce(model)
if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
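# NoamLR here is the Transformer-style warmup schedule: the learning rate ramps up
# roughly linearly for warmup_steps and then decays as ~1/sqrt(step). The multiplier
# below is the usual formulation of that rule, shown as an assumption for orientation,
# not copied from the implementation.
def noam_multiplier(step, warmup_steps=4000):
    step = max(step, 1)
    return warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

base_lr = 1e-3
for s in (1, 1000, 4000, 16000):
    print(s, base_lr * noam_multiplier(s))   # peaks at base_lr when step == warmup_steps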
@ -587,22 +671,17 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Model has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
# define data loaders
train_loader = setup_loader(ap,
model.decoder.r,
is_val=False,
verbose=True)
train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True)
eval_loader = setup_loader(ap, model.decoder.r, is_val=True)
global_step = args.restore_step
@ -617,28 +696,29 @@ def main(args): # pylint: disable=redefined-outer-name
model.decoder_backward.set_r(r)
train_loader.dataset.outputs_per_step = r
eval_loader.dataset.outputs_per_step = r
train_loader = setup_loader(ap,
model.decoder.r,
is_val=False,
dataset=train_loader.dataset)
eval_loader = setup_loader(ap,
model.decoder.r,
is_val=True,
dataset=eval_loader.dataset)
train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset)
eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset)
print("\n > Number of output frames:", model.decoder.r)
# train one epoch
train_avg_loss_dict, global_step = train(train_loader, model,
criterion, optimizer,
optimizer_st, scheduler, ap,
global_step, epoch, scaler,
scaler_st)
train_avg_loss_dict, global_step = train(
train_loader,
model,
criterion,
optimizer,
optimizer_st,
scheduler,
ap,
global_step,
epoch,
scaler,
scaler_st,
)
# eval one epoch
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_postnet_loss']
target_loss = train_avg_loss_dict["avg_postnet_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_postnet_loss']
target_loss = eval_avg_loss_dict["avg_postnet_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
@ -651,14 +731,13 @@ def main(args): # pylint: disable=redefined-outer-name
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
scaler=scaler.state_dict() if c.mixed_precision else None
scaler=scaler.state_dict() if c.mixed_precision else None,
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")
try:
main(args)

View File

@ -12,8 +12,7 @@ import torch
from torch.utils.data import DataLoader
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
@ -21,8 +20,7 @@ from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
setup_generator)
from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
# DISTRIBUTED
@ -36,27 +34,30 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
loader = None
if not is_val or c.run_eval:
dataset = GANDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=not is_val,
use_noise_augment=c.use_noise_augment,
use_cache=c.use_cache,
verbose=verbose)
dataset = GANDataset(
ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=not is_val,
use_noise_augment=c.use_noise_augment,
use_cache=c.use_cache,
verbose=verbose,
)
dataset.shuffle_mapping()
sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=1 if is_val else c.batch_size,
shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
loader = DataLoader(
dataset,
batch_size=1 if is_val else c.batch_size,
shuffle=num_gpus == 0,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
@ -83,16 +84,26 @@ def format_data(data):
return co, x, None, None
def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
scheduler_G, scheduler_D, ap, global_step, epoch):
def train(
model_G,
criterion_G,
optimizer_G,
model_D,
criterion_D,
optimizer_D,
scheduler_G,
scheduler_D,
ap,
global_step,
epoch,
):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model_G.train()
model_D.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -148,16 +159,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
scores_fake = D_out_fake
# compute losses
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
feats_real, y_hat_sub, y_G_sub)
loss_G = loss_G_dict['G_loss']
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)
loss_G = loss_G_dict["G_loss"]
# optimizer generator
optimizer_G.zero_grad()
loss_G.backward()
if c.gen_clip_grad > 0:
torch.nn.utils.clip_grad_norm_(model_G.parameters(),
c.gen_clip_grad)
torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
optimizer_G.step()
if scheduler_G is not None:
scheduler_G.step()
@ -202,14 +211,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
# compute losses
loss_D_dict = criterion_D(scores_fake, scores_real)
loss_D = loss_D_dict['D_loss']
loss_D = loss_D_dict["D_loss"]
# optimizer discriminator
optimizer_D.zero_grad()
loss_D.backward()
if c.disc_clip_grad > 0:
torch.nn.utils.clip_grad_norm_(model_D.parameters(),
c.disc_clip_grad)
torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
optimizer_D.step()
if scheduler_D is not None:
scheduler_D.step()
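# The two update blocks above are the usual alternating GAN step: the generator is
# optimized against the current discriminator, then the discriminator is optimized on
# real samples and detached fakes. A bare-bones sketch of that loop; the models, shapes
# and LSGAN-style losses below are stand-ins, not the vocoder's.
import torch

G = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.ReLU(), torch.nn.Linear(64, 128))
D = torch.nn.Sequential(torch.nn.Linear(128, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1))
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)
opt_D = torch.optim.Adam(D.parameters(), lr=1e-4)
cond, y_real = torch.randn(8, 16), torch.randn(8, 128)

# generator step: fool D, plus an L1 term standing in for the spectral losses
y_hat = G(cond)
loss_G = torch.mean((D(y_hat) - 1.0) ** 2) + torch.nn.functional.l1_loss(y_hat, y_real)
opt_G.zero_grad()
loss_G.backward()
torch.nn.utils.clip_grad_norm_(G.parameters(), 10.0)
opt_G.step()

# discriminator step: push real scores toward 1 and detached fake scores toward 0
loss_D = torch.mean((D(y_real) - 1.0) ** 2) + torch.mean(D(y_hat.detach()) ** 2)
opt_D.zero_grad()
loss_D.backward()
torch.nn.utils.clip_grad_norm_(D.parameters(), 10.0)
opt_D.step()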
@ -224,36 +232,31 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
epoch_time += step_time
# get current learning rates
current_lr_G = list(optimizer_G.param_groups)[0]['lr']
current_lr_D = list(optimizer_D.param_groups)[0]['lr']
current_lr_G = list(optimizer_G.param_groups)[0]["lr"]
current_lr_D = list(optimizer_D.param_groups)[0]["lr"]
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training stats
if global_step % c.print_step == 0:
log_dict = {
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr_G": current_lr_G,
"current_lr_D": current_lr_D
"current_lr_D": current_lr_D,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# plot step stats
if global_step % 10 == 0:
iter_stats = {
"lr_G": current_lr_G,
"lr_D": current_lr_D,
"step_time": step_time
}
iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
@ -261,27 +264,26 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model_G,
optimizer_G,
scheduler_G,
model_D,
optimizer_D,
scheduler_D,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict)
save_checkpoint(
model_G,
optimizer_G,
scheduler_G,
model_D,
optimizer_D,
scheduler_D,
global_step,
epoch,
OUT_PATH,
model_losses=loss_dict,
)
# compute spectrograms
figures = plot_results(y_hat_vis, y_G, ap, global_step,
'train')
figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_train_audios(global_step,
{'train/audio': sample_voice},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -356,8 +358,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
feats_fake, feats_real = None, None
# compute losses
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
feats_real, y_hat_sub, y_G_sub)
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)
loss_dict = dict()
for key, value in loss_G_dict.items():
@ -413,9 +414,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values['avg_' + key] = value
update_eval_values['avg_loader_time'] = loader_time
update_eval_values['avg_step_time'] = step_time
update_eval_values["avg_" + key] = value
update_eval_values["avg_loader_time"] = loader_time
update_eval_values["avg_step_time"] = step_time
keep_avg.update_values(update_eval_values)
# print eval stats
@ -424,13 +425,12 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
if args.rank == 0:
# compute spectrograms
figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
figures = plot_results(y_hat, y_G, ap, global_step, "eval")
tb_logger.tb_eval_figures(global_step, figures)
# Sample audio
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@ -446,8 +446,7 @@ def main(args): # pylint: disable=redefined-outer-name
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(
c.data_path, c.feature_path, c.eval_split_size)
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@ -456,8 +455,7 @@ def main(args): # pylint: disable=redefined-outer-name
    # DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
# setup models
model_gen = setup_generator(c)
@ -465,21 +463,17 @@ def main(args): # pylint: disable=redefined-outer-name
# setup optimizers
optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
optimizer_disc = RAdam(model_disc.parameters(),
lr=c.lr_disc,
weight_decay=0)
optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0)
# schedulers
scheduler_gen = None
scheduler_disc = None
if 'lr_scheduler_gen' in c:
if "lr_scheduler_gen" in c:
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
scheduler_gen = scheduler_gen(
optimizer_gen, **c.lr_scheduler_gen_params)
if 'lr_scheduler_disc' in c:
scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
if "lr_scheduler_disc" in c:
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
scheduler_disc = scheduler_disc(
optimizer_disc, **c.lr_scheduler_disc_params)
scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
# setup criterion
criterion_gen = GeneratorLoss(c)
@ -487,47 +481,46 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
print(" > Restoring Generator Model...")
model_gen.load_state_dict(checkpoint['model'])
model_gen.load_state_dict(checkpoint["model"])
print(" > Restoring Generator Optimizer...")
optimizer_gen.load_state_dict(checkpoint['optimizer'])
optimizer_gen.load_state_dict(checkpoint["optimizer"])
print(" > Restoring Discriminator Model...")
model_disc.load_state_dict(checkpoint['model_disc'])
model_disc.load_state_dict(checkpoint["model_disc"])
print(" > Restoring Discriminator Optimizer...")
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
if 'scheduler' in checkpoint and scheduler_gen is not None:
optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
if "scheduler" in checkpoint and scheduler_gen is not None:
print(" > Restoring Generator LR Scheduler...")
scheduler_gen.load_state_dict(checkpoint['scheduler'])
scheduler_gen.load_state_dict(checkpoint["scheduler"])
# NOTE: Not sure if necessary
scheduler_gen.optimizer = optimizer_gen
if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
if "scheduler_disc" in checkpoint and scheduler_disc is not None:
print(" > Restoring Discriminator LR Scheduler...")
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
scheduler_disc.optimizer = optimizer_disc
except RuntimeError:
# restore only matching layers.
print(" > Partial model initialization...")
model_dict = model_gen.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model_gen.load_state_dict(model_dict)
model_dict = model_disc.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
model_disc.load_state_dict(model_dict)
del model_dict
            # reset lr if not continuing training.
for group in optimizer_gen.param_groups:
group['lr'] = c.lr_gen
group["lr"] = c.lr_gen
for group in optimizer_disc.param_groups:
group['lr'] = c.lr_disc
group["lr"] = c.lr_disc
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
args.restore_step = checkpoint['step']
print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -548,50 +541,55 @@ def main(args): # pylint: disable=redefined-outer-name
print(" > Discriminator has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with best loss of {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
global_step = args.restore_step
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model_gen, criterion_gen, optimizer_gen,
model_disc, criterion_disc, optimizer_disc,
scheduler_gen, scheduler_disc, ap, global_step,
epoch)
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
criterion_disc, ap,
global_step, epoch)
_, global_step = train(
model_gen,
criterion_gen,
optimizer_gen,
model_disc,
criterion_disc,
optimizer_disc,
scheduler_gen,
scheduler_disc,
ap,
global_step,
epoch,
)
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss]
best_loss = save_best_model(target_loss,
best_loss,
model_gen,
optimizer_gen,
scheduler_gen,
model_disc,
optimizer_disc,
scheduler_disc,
global_step,
epoch,
OUT_PATH,
keep_all_best=keep_all_best,
keep_after=keep_after,
model_losses=eval_avg_loss_dict,
)
best_loss = save_best_model(
target_loss,
best_loss,
model_gen,
optimizer_gen,
scheduler_gen,
model_disc,
optimizer_disc,
scheduler_disc,
global_step,
epoch,
OUT_PATH,
keep_all_best=keep_all_best,
keep_after=keep_after,
model_losses=eval_avg_loss_dict,
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='vocoder')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
try:
main(args)

View File

@ -8,6 +8,7 @@ import traceback
import numpy as np
import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.optim import Adam
@ -16,8 +17,7 @@ from torch.utils.data.distributed import DistributedSampler
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
@ -31,27 +31,29 @@ def setup_loader(ap, is_val=False, verbose=False):
if is_val and not c.run_eval:
loader = None
else:
dataset = WaveGradDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=True,
use_noise_augment=False,
use_cache=c.use_cache,
verbose=verbose)
dataset = WaveGradDataset(
ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad_short=c.pad_short,
conv_pad=c.conv_pad,
is_training=not is_val,
return_segments=True,
use_noise_augment=False,
use_cache=c.use_cache,
verbose=verbose,
)
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.batch_size,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
loader = DataLoader(
dataset,
batch_size=c.batch_size,
shuffle=num_gpus <= 1,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader
@ -77,24 +79,21 @@ def format_test_data(data):
return m, x
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
epoch):
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch):
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
c_logger.print_train_start()
# setup noise schedule
noise_schedule = c['train_noise_schedule']
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
noise_schedule['num_steps'])
if hasattr(model, 'module'):
noise_schedule = c["train_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
if hasattr(model, "module"):
model.module.compute_noise_level(betas)
else:
model.compute_noise_level(betas)
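# compute_noise_level derives, from the linear beta schedule above, the cumulative
# noise levels used to corrupt the clean waveform during training. The sketch below
# uses the usual diffusion bookkeeping (alpha / alpha_hat) with invented schedule
# values; it is an illustration of the convention, not the model's attributes.
import numpy as np

toy_schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}
toy_betas = np.linspace(toy_schedule["min_val"], toy_schedule["max_val"], toy_schedule["num_steps"])
alphas = 1.0 - toy_betas                 # per-step signal retention
alpha_hat = np.cumprod(alphas)           # cumulative retention over the diffusion steps
signal_scale = np.sqrt(alpha_hat)        # y_noisy = signal_scale * y + sqrt(1 - alpha_hat) * noise
print(toy_betas[500], signal_scale[500], np.sqrt(1.0 - alpha_hat[500]))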
@ -109,7 +108,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
# compute noisy input
if hasattr(model, 'module'):
if hasattr(model, "module"):
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
else:
noise, x_noisy, noise_scale = model.compute_y_n(x)
@ -119,11 +118,11 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
# compute losses
loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss': loss}
loss_wavegrad_dict = {"wavegrad_loss": loss}
# check nan loss
if torch.isnan(loss).any():
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
raise RuntimeError(f"Detected NaN loss at step {global_step}.")
optimizer.zero_grad()
@ -131,14 +130,12 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
if c.mixed_precision:
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad)
scaler.step(optimizer)
scaler.update()
else:
loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.clip_grad)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad)
optimizer.step()
# schedule update
@ -158,35 +155,30 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
epoch_time += step_time
# get current learning rates
current_lr = list(optimizer.param_groups)[0]['lr']
current_lr = list(optimizer.param_groups)[0]["lr"]
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)
# print training stats
if global_step % c.print_step == 0:
log_dict = {
'step_time': [step_time, 2],
'loader_time': [loader_time, 4],
"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": current_lr,
"grad_norm": grad_norm.item()
"grad_norm": grad_norm.item(),
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
if args.rank == 0:
# plot step stats
if global_step % 10 == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm.item(),
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)
@ -205,7 +197,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
epoch,
OUT_PATH,
model_losses=loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None
scaler=scaler.state_dict() if c.mixed_precision else None,
)
end_time = time.time()
@ -242,19 +234,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
global_step += 1
# compute noisy input
if hasattr(model, 'module'):
if hasattr(model, "module"):
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
else:
noise, x_noisy, noise_scale = model.compute_y_n(x)
# forward pass
noise_hat = model(x_noisy, m, noise_scale)
# compute losses
loss = criterion(noise, noise_hat)
loss_wavegrad_dict = {'wavegrad_loss': loss}
loss_wavegrad_dict = {"wavegrad_loss": loss}
loss_dict = dict()
for key, value in loss_wavegrad_dict.items():
@ -269,9 +259,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
# update avg stats
update_eval_values = dict()
for key, value in loss_dict.items():
update_eval_values['avg_' + key] = value
update_eval_values['avg_loader_time'] = loader_time
update_eval_values['avg_step_time'] = step_time
update_eval_values["avg_" + key] = value
update_eval_values["avg_loader_time"] = loader_time
update_eval_values["avg_step_time"] = step_time
keep_avg.update_values(update_eval_values)
# print eval stats
@ -284,11 +274,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
m, x = format_test_data(samples[0])
# setup noise schedule and inference
noise_schedule = c['test_noise_schedule']
betas = np.linspace(noise_schedule['min_val'],
noise_schedule['max_val'],
noise_schedule['num_steps'])
if hasattr(model, 'module'):
noise_schedule = c["test_noise_schedule"]
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
if hasattr(model, "module"):
model.module.compute_noise_level(betas)
# compute voice
x_pred = model.module.inference(m)
@ -298,13 +286,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
x_pred = model.inference(m)
# compute spectrograms
figures = plot_results(x_pred, x, ap, global_step, 'eval')
figures = plot_results(x_pred, x, ap, global_step, "eval")
tb_logger.tb_eval_figures(global_step, figures)
# Sample audio
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
data_loader.dataset.return_segments = True
@ -318,8 +305,7 @@ def main(args): # pylint: disable=redefined-outer-name
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
c.eval_split_size)
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
else:
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
@ -328,8 +314,7 @@ def main(args): # pylint: disable=redefined-outer-name
    # DISTRIBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
# setup models
model = setup_generator(c)
@ -342,7 +327,7 @@ def main(args): # pylint: disable=redefined-outer-name
# schedulers
scheduler = None
if 'lr_scheduler' in c:
if "lr_scheduler" in c:
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
@ -355,15 +340,15 @@ def main(args): # pylint: disable=redefined-outer-name
if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
print(" > Restoring Model...")
model.load_state_dict(checkpoint['model'])
model.load_state_dict(checkpoint["model"])
print(" > Restoring Optimizer...")
optimizer.load_state_dict(checkpoint['optimizer'])
if 'scheduler' in checkpoint:
optimizer.load_state_dict(checkpoint["optimizer"])
if "scheduler" in checkpoint:
print(" > Restoring LR Scheduler...")
scheduler.load_state_dict(checkpoint['scheduler'])
scheduler.load_state_dict(checkpoint["scheduler"])
# NOTE: Not sure if necessary
scheduler.optimizer = optimizer
if "scaler" in checkpoint and c.mixed_precision:
@ -373,17 +358,16 @@ def main(args): # pylint: disable=redefined-outer-name
            # restore only matching layers.
print(" > Partial model initialization...")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict
            # reset lr if not continuing training.
for group in optimizer.param_groups:
group['lr'] = c.lr
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -395,22 +379,19 @@ def main(args): # pylint: disable=redefined-outer-name
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
global_step = args.restore_step
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model, criterion, optimizer, scheduler, scaler,
ap, global_step, epoch)
_, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch)
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict[c.target_loss]
@ -429,14 +410,13 @@ def main(args): # pylint: disable=redefined-outer-name
keep_all_best=keep_all_best,
keep_after=keep_after,
model_losses=eval_avg_loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None
scaler=scaler.state_dict() if c.mixed_precision else None,
)
if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='vocoder')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
try:
main(args)

View File

@ -24,10 +24,7 @@ from TTS.utils.generic_utils import (
set_init_dict,
)
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.datasets.preprocess import (
load_wav_data,
load_wav_feat_data
)
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
from TTS.vocoder.utils.generic_utils import setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
@ -40,26 +37,26 @@ def setup_loader(ap, is_val=False, verbose=False):
if is_val and not c.run_eval:
loader = None
else:
dataset = WaveRNNDataset(ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad=c.padding,
mode=c.mode,
mulaw=c.mulaw,
is_training=not is_val,
verbose=verbose,
)
dataset = WaveRNNDataset(
ap=ap,
items=eval_data if is_val else train_data,
seq_len=c.seq_len,
hop_len=ap.hop_length,
pad=c.padding,
mode=c.mode,
mulaw=c.mulaw,
is_training=not is_val,
verbose=verbose,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
shuffle=True,
collate_fn=dataset.collate,
batch_size=c.batch_size,
num_workers=c.num_val_loader_workers
if is_val
else c.num_loader_workers,
pin_memory=True,
)
loader = DataLoader(
dataset,
shuffle=True,
collate_fn=dataset.collate,
batch_size=c.batch_size,
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=True,
)
return loader
@ -85,8 +82,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(len(data_loader.dataset) /
(c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()
@ -114,8 +110,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
if c.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), c.grad_clip)
torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
@ -132,8 +127,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
raise RuntimeError(" [!] None loss. Exiting ...")
loss.backward()
if c.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(
model.parameters(), c.grad_clip)
torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()
if scheduler is not None:
@ -156,17 +150,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
# print training stats
if global_step % c.print_step == 0:
log_dict = {"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": cur_lr,
}
c_logger.print_train_step(batch_n_iter,
num_iter,
global_step,
log_dict,
loss_dict,
keep_avg.avg_values,
)
log_dict = {
"step_time": [step_time, 2],
"loader_time": [loader_time, 4],
"current_lr": cur_lr,
}
c_logger.print_train_step(
batch_n_iter,
num_iter,
global_step,
log_dict,
loss_dict,
keep_avg.avg_values,
)
# plot step stats
if global_step % 10 == 0:
@ -189,36 +185,36 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
epoch,
OUT_PATH,
model_losses=loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None
scaler=scaler.state_dict() if c.mixed_precision else None,
)
# synthesize a full voice
rand_idx = random.randrange(0, len(train_data))
wav_path = train_data[rand_idx] if not isinstance(
train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
wav_path = (
train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
)
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav)
ground_mel = torch.FloatTensor(ground_mel)
if use_cuda:
ground_mel = ground_mel.cuda(non_blocking=True)
sample_wav = model.inference(ground_mel,
c.batched,
c.target_samples,
c.overlap_samples,
)
sample_wav = model.inference(
ground_mel,
c.batched,
c.target_samples,
c.overlap_samples,
)
predict_mel = ap.melspectrogram(sample_wav)
# compute spectrograms
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T)
}
figures = {
"train/ground_truth": plot_spectrogram(ground_mel.T),
"train/prediction": plot_spectrogram(predict_mel.T),
}
tb_logger.tb_train_figures(global_step, figures)
# Sample audio
tb_logger.tb_train_audios(
global_step, {
"train/audio": sample_wav}, c.audio["sample_rate"]
)
tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"])
end_time = time.time()
# print epoch stats
@ -277,36 +273,32 @@ def evaluate(model, criterion, ap, global_step, epoch):
# print eval stats
if c.print_eval:
c_logger.print_eval_step(
num_iter, loss_dict, keep_avg.avg_values)
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
if epoch % c.test_every_epochs == 0 and epoch != 0:
# synthesize a full voice
rand_idx = random.randrange(0, len(eval_data))
wav_path = eval_data[rand_idx] if not isinstance(
eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
wav = ap.load_wav(wav_path)
ground_mel = ap.melspectrogram(wav)
ground_mel = torch.FloatTensor(ground_mel)
if use_cuda:
ground_mel = ground_mel.cuda(non_blocking=True)
sample_wav = model.inference(ground_mel,
c.batched,
c.target_samples,
c.overlap_samples,
)
sample_wav = model.inference(
ground_mel,
c.batched,
c.target_samples,
c.overlap_samples,
)
predict_mel = ap.melspectrogram(sample_wav)
# Sample audio
tb_logger.tb_eval_audios(
global_step, {
"eval/audio": sample_wav}, c.audio["sample_rate"]
)
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"])
# compute spectrograms
figures = {
"eval/ground_truth": plot_spectrogram(ground_mel.T),
"eval/prediction": plot_spectrogram(predict_mel.T)
"eval/prediction": plot_spectrogram(predict_mel.T),
}
tb_logger.tb_eval_figures(global_step, figures)
@ -347,11 +339,9 @@ def main(args): # pylint: disable=redefined-outer-name
print(f" > Loading wavs from: {c.data_path}")
if c.feature_path is not None:
print(f" > Loading features from: {c.feature_path}")
eval_data, train_data = load_wav_feat_data(
c.data_path, c.feature_path, c.eval_split_size)
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
else:
eval_data, train_data = load_wav_data(
c.data_path, c.eval_split_size)
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
# setup model
model_wavernn = setup_generator(c)
@ -404,8 +394,7 @@ def main(args): # pylint: disable=redefined-outer-name
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model_wavernn.load_state_dict(model_dict)
print(" > Model restored from step %d" %
checkpoint["step"], flush=True)
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@ -418,24 +407,20 @@ def main(args): # pylint: disable=redefined-outer-name
print(" > Model has {} parameters".format(num_parameters), flush=True)
if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
global_step = args.restore_step
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
_, global_step = train(model_wavernn, optimizer,
criterion, scheduler, scaler, ap, global_step, epoch)
eval_avg_loss_dict = evaluate(
model_wavernn, criterion, ap, global_step, epoch)
_, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch)
eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = eval_avg_loss_dict["avg_model_loss"]
best_loss = save_best_model(
@ -453,14 +438,13 @@ def main(args): # pylint: disable=redefined-outer-name
keep_all_best=keep_all_best,
keep_after=keep_after,
model_losses=eval_avg_loss_dict,
scaler=scaler.state_dict() if c.mixed_precision else None
scaler=scaler.state_dict() if c.mixed_precision else None,
)
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='vocoder')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
try:
main(args)
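For context on the mixed-precision branch reformatted above, the scale/unscale/clip/step sequence is the standard torch.cuda.amp recipe. A minimal sketch, assuming placeholder model, optimizer, and criterion objects and an illustrative grad_clip value:

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_train_step(model, optimizer, criterion, x, y, grad_clip=4.0):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        y_hat = model(x)
        loss = criterion(y_hat, y)
    scaler.scale(loss).backward()
    # gradients must be unscaled before clipping, exactly as in the training loop above
    scaler.unscale_(optimizer)
    if grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    return loss.item()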

View File

@ -13,14 +13,21 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
from TTS.vocoder.utils.generic_utils import setup_generator
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
parser.add_argument('--config_path', type=str, help='Path to model config file.')
parser.add_argument('--data_path', type=str, help='Path to data directory.')
parser.add_argument('--output_path', type=str, help='Path to the output file, including file name and extension.')
parser.add_argument('--num_iter', type=int, help='Number of model inference iterations to optimize the noise schedule for.')
parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.')
parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
parser.add_argument("--config_path", type=str, help="Path to model config file.")
parser.add_argument("--data_path", type=str, help="Path to data directory.")
parser.add_argument("--output_path", type=str, help="Path to the output file, including file name and extension.")
parser.add_argument(
"--num_iter", type=int, help="Number of model inference iterations to optimize the noise schedule for."
)
parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
parser.add_argument("--num_samples", type=int, default=1, help="Number of data samples used for inference.")
parser.add_argument(
"--search_depth",
type=int,
default=3,
help="Search granularity. Increasing this increases the run-time exponentially.",
)
# load config
args = parser.parse_args()
@ -31,18 +38,20 @@ ap = AudioProcessor(**config.audio)
# load dataset
_, train_data = load_wav_data(args.data_path, 0)
train_data = train_data[:args.num_samples]
dataset = WaveGradDataset(ap=ap,
items=train_data,
seq_len=-1,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=True,
return_segments=False,
use_noise_augment=False,
use_cache=False,
verbose=True)
train_data = train_data[: args.num_samples]
dataset = WaveGradDataset(
ap=ap,
items=train_data,
seq_len=-1,
hop_len=ap.hop_length,
pad_short=config.pad_short,
conv_pad=config.conv_pad,
is_training=True,
return_segments=False,
use_noise_augment=False,
use_cache=False,
verbose=True,
)
loader = DataLoader(
dataset,
batch_size=1,
@ -50,7 +59,8 @@ loader = DataLoader(
collate_fn=dataset.collate_full_clips,
drop_last=False,
num_workers=config.num_loader_workers,
pin_memory=False)
pin_memory=False,
)
# setup the model
model = setup_generator(config)
@ -61,9 +71,9 @@ if args.use_cuda:
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
print(base_values)
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
best_error = float('inf')
best_error = float("inf")
best_schedule = None
total_search_iter = len(base_values)**args.num_iter
total_search_iter = len(base_values) ** args.num_iter
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
beta = exponents * base
model.compute_noise_level(beta)
@ -84,6 +94,6 @@ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=tot
mse = torch.sum((mel - mel_hat) ** 2).mean()
if mse.item() < best_error:
best_error = mse.item()
best_schedule = {'beta': beta}
best_schedule = {"beta": beta}
print(f" > Found a better schedule. - MSE: {mse.item()}")
np.save(args.output_path, best_schedule)
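A short sketch of how the saved schedule might be consumed afterwards; the file name is a placeholder, and compute_noise_level is the same call used in the search loop above:

import numpy as np

schedule = np.load("noise_schedule.npy", allow_pickle=True).item()  # the {"beta": ...} dict saved above
beta = schedule["beta"]
# model.compute_noise_level(beta)  # then run model.inference(mel) with the tuned schedule
print(f"{len(beta)} inference steps in the tuned schedule")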

View File

@ -13,23 +13,40 @@ from TTS.utils.io import load_config
def create_argparser():
def convert_boolean(x):
return x.lower() in ['true', '1', 'yes']
return x.lower() in ["true", "1", "yes"]
parser = argparse.ArgumentParser()
parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.')
parser.add_argument('--model_name', type=str, default="tts_models/en/ljspeech/speedy-speech-wn", help='name of one of the released tts models.')
parser.add_argument('--vocoder_name', type=str, default=None, help='name of one of the released vocoder models.')
parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file')
parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file')
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.')
parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.')
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.')
parser.add_argument(
"--list_models",
type=convert_boolean,
nargs="?",
const=True,
default=False,
help="list available pre-trained tts and vocoder models.",
)
parser.add_argument(
"--model_name",
type=str,
default="tts_models/en/ljspeech/speedy-speech-wn",
help="name of one of the released tts models.",
)
parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.")
parser.add_argument("--tts_checkpoint", type=str, help="path to custom tts checkpoint file")
parser.add_argument("--tts_config", type=str, help="path to custom tts config.json file")
parser.add_argument(
"--tts_speakers",
type=str,
help="path to JSON file containing speaker ids, if speaker ids are used in the model",
)
parser.add_argument("--vocoder_config", type=str, default=None, help="path to vocoder config file.")
parser.add_argument("--vocoder_checkpoint", type=str, default=None, help="path to vocoder checkpoint file.")
parser.add_argument("--port", type=int, default=5002, help="port to listen on.")
parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.")
parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.")
parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.")
return parser
# parse the args
args = create_argparser().parse_args()
@ -43,7 +60,7 @@ if args.list_models:
# update in-use models to the specified released models.
if args.model_name is not None:
tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name)
args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
args.vocoder_name = tts_json_dict["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
if args.vocoder_name is not None:
vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name)
@ -59,16 +76,19 @@ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
args.vocoder_config = vocoder_config_file
synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
synthesizer = Synthesizer(
args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
)
app = Flask(__name__)
@app.route('/')
@app.route("/")
def index():
return render_template('index.html', show_details=args.show_details)
return render_template("index.html", show_details=args.show_details)
@app.route('/details')
@app.route("/details")
def details():
model_config = load_config(args.tts_config)
if args.vocoder_config is not None and os.path.isfile(args.vocoder_config):
@ -76,26 +96,28 @@ def details():
else:
vocoder_config = None
return render_template('details.html',
show_details=args.show_details
, model_config=model_config
, vocoder_config=vocoder_config
, args=args.__dict__
)
return render_template(
"details.html",
show_details=args.show_details,
model_config=model_config,
vocoder_config=vocoder_config,
args=args.__dict__,
)
@app.route('/api/tts', methods=['GET'])
@app.route("/api/tts", methods=["GET"])
def tts():
text = request.args.get('text')
text = request.args.get("text")
print(" > Model input: {}".format(text))
wavs = synthesizer.tts(text)
out = io.BytesIO()
synthesizer.save_wav(wavs, out)
return send_file(out, mimetype='audio/wav')
return send_file(out, mimetype="audio/wav")
def main():
app.run(debug=args.debug, host='0.0.0.0', port=args.port)
app.run(debug=args.debug, host="0.0.0.0", port=args.port)
if __name__ == '__main__':
if __name__ == "__main__":
main()
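A minimal client sketch against the /api/tts route defined above, assuming the server runs locally on the default --port 5002:

import requests

resp = requests.get("http://localhost:5002/api/tts", params={"text": "Hello world."})
resp.raise_for_status()
with open("tts_output.wav", "wb") as f:
    f.write(resp.content)  # the endpoint returns audio/wav bytes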

View File

@ -7,9 +7,19 @@ from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
num_utter_per_speaker=10, skip_speakers=False, verbose=False):
def __init__(
self,
ap,
meta_data,
voice_len=1.6,
num_speakers_in_batch=64,
storage_size=1,
sample_from_storage_p=0.5,
additive_noise=0,
num_utter_per_speaker=10,
skip_speakers=False,
verbose=False,
):
"""
Args:
ap (TTS.tts.utils.AudioProcessor): audio processor object.
@ -28,7 +38,7 @@ class MyDataset(Dataset):
self.ap = ap
self.verbose = verbose
self.__parse_items()
self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch)
self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
self.sample_from_storage_p = float(sample_from_storage_p)
self.additive_noise = float(additive_noise)
if self.verbose:
@ -69,11 +79,14 @@ class MyDataset(Dataset):
if speaker_ in self.speaker_to_utters.keys():
self.speaker_to_utters[speaker_].append(path_)
else:
self.speaker_to_utters[speaker_] = [path_, ]
self.speaker_to_utters[speaker_] = [
path_,
]
if self.skip_speakers:
self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items() if
len(v) >= self.num_utter_per_speaker}
self.speaker_to_utters = {
k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
}
self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
@ -100,13 +113,9 @@ class MyDataset(Dataset):
def __sample_speaker(self):
speaker = random.sample(self.speakers, 1)[0]
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
utters = random.choices(
self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
)
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
else:
utters = random.sample(
self.speaker_to_utters[speaker], self.num_utter_per_speaker
)
utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
return speaker, utters
def __sample_speaker_utterances(self, speaker):
@ -160,7 +169,9 @@ class MyDataset(Dataset):
# get a random subset of each of the wavs and convert to mel spectrograms.
offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
mels_ = [
self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_))
]
feats_ = [torch.FloatTensor(mel) for mel in mels_]
labels.append(labels_)
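The collate step above crops a fixed-length window from each utterance before computing features; a stand-alone sketch of that step, assuming ap is the repo's AudioProcessor and seq_len is a sample count:

import random
import torch

def random_mel_segment(ap, wav, seq_len):
    # random offset so different epochs see different slices of the same utterance
    offset = random.randint(0, wav.shape[0] - seq_len)
    mel = ap.melspectrogram(wav[offset : offset + seq_len])
    return torch.FloatTensor(mel)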

View File

@ -16,14 +16,14 @@ class GE2ELoss(nn.Module):
- init_w (float): defines the initial value of w in Equation (5) of [1]
- init_b (float): defines the initial value of b in Equation (5) of [1]
"""
super(GE2ELoss, self).__init__()
super().__init__()
# pylint: disable=E1102
self.w = nn.Parameter(torch.tensor(init_w))
# pylint: disable=E1102
self.b = nn.Parameter(torch.tensor(init_b))
self.loss_method = loss_method
print(' > Initialised Generalized End-to-End loss')
print(" > Initialised Generalized End-to-End loss")
assert self.loss_method in ["softmax", "contrast"]
@ -55,9 +55,7 @@ class GE2ELoss(nn.Module):
for spkr_idx, speaker in enumerate(dvecs):
cs_row = []
for utt_idx, utterance in enumerate(speaker):
new_centroids = self.calc_new_centroids(
dvecs, centroids, spkr_idx, utt_idx
)
new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
# vector based cosine similarity for speed
cs_row.append(
torch.clamp(
@ -99,14 +97,8 @@ class GE2ELoss(nn.Module):
L_row = []
for i in range(M):
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
excl_centroids_sigmoids = torch.cat(
(centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
)
L_row.append(
1.0
- torch.sigmoid(cos_sim_matrix[j, i, j])
+ torch.max(excl_centroids_sigmoids)
)
excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]))
L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
L_row = torch.stack(L_row)
L.append(L_row)
return torch.stack(L)
@ -122,6 +114,7 @@ class GE2ELoss(nn.Module):
L = self.embed_loss(dvecs, cos_sim_matrix)
return L.mean()
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
class AngleProtoLoss(nn.Module):
"""
@ -134,15 +127,16 @@ class AngleProtoLoss(nn.Module):
- init_w (float): defines the initial value of w
- init_b (float): defines the initial value of b
"""
def __init__(self, init_w=10.0, init_b=-5.0):
super(AngleProtoLoss, self).__init__()
super().__init__()
# pylint: disable=E1102
self.w = nn.Parameter(torch.tensor(init_w))
# pylint: disable=E1102
self.b = nn.Parameter(torch.tensor(init_b))
self.criterion = torch.nn.CrossEntropyLoss()
print(' > Initialised Angular Prototypical loss')
print(" > Initialised Angular Prototypical loss")
def forward(self, x):
"""
@ -152,7 +146,10 @@ class AngleProtoLoss(nn.Module):
out_positive = x[:, 0, :]
num_speakers = out_anchor.size()[0]
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
cos_sim_matrix = F.cosine_similarity(
out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
)
torch.clamp(self.w, 1e-6)
cos_sim_matrix = cos_sim_matrix * self.w + self.b
label = torch.arange(num_speakers).to(cos_sim_matrix.device)
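A small usage sketch for the losses above. The d-vector layout [n_speakers, n_utterances, embed_dim] follows the nested loops over speakers and utterances in the code; the embedding size and the constructor defaults shown here are assumptions:

import torch

criterion = GE2ELoss(init_w=10.0, init_b=-5.0, loss_method="softmax")  # defaults are assumed
dvecs = torch.randn(4, 5, 256, requires_grad=True)  # 4 speakers x 5 utterances x 256-dim embeddings
loss = criterion(dvecs)  # scalar mean loss, as returned by forward() above
loss.backward()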

View File

@ -16,19 +16,19 @@ class LSTMWithProjection(nn.Module):
o, (_, _) = self.lstm(x)
return self.linear(o)
class LSTMWithoutProjection(nn.Module):
def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
super().__init__()
self.lstm = nn.LSTM(input_size=input_dim,
hidden_size=lstm_dim,
num_layers=num_lstm_layers,
batch_first=True)
self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
self.relu = nn.ReLU()
def forward(self, x):
_, (hidden, _) = self.lstm(x)
return self.relu(self.linear(hidden[-1]))
class SpeakerEncoder(nn.Module):
def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
super().__init__()
@ -106,7 +106,5 @@ class SpeakerEncoder(nn.Module):
if embed is None:
embed = self.inference(frames)
else:
embed[cur_iter <= num_iters, :] += self.inference(
frames[cur_iter <= num_iters, :, :]
)
embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
return embed / num_iters
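A sketch of instantiating the encoder above and extracting an embedding; the mel dimension and frame count are placeholders, and inference() is assumed to return one d-vector per input clip, as suggested by the batched-embedding code:

import torch

model = SpeakerEncoder(input_dim=80, proj_dim=256, lstm_dim=768, num_lstm_layers=3)
model.eval()
frames = torch.randn(1, 250, 80)  # [batch, time, input_dim]; the LSTM above is batch_first
with torch.no_grad():
    embedding = model.inference(frames)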

View File

@ -9,108 +9,114 @@ from TTS.utils.generic_utils import check_argument
def to_camel(text):
text = text.capitalize()
return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
def setup_model(c):
model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
c.model['lstm_dim'], c.model['num_lstm_layers'])
model = SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
return model
def save_checkpoint(model, optimizer, model_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch):
checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print(" | | > Checkpoint saving : {}".format(checkpoint_path))
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
'loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
"model": new_state_dict,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"step": current_step,
"epoch": epoch,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
torch.save(state, checkpoint_path)
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step):
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
if model_loss < best_loss:
new_state_dict = model.state_dict()
state = {
'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'loss': model_loss,
'date': datetime.date.today().strftime("%B %d, %Y"),
"model": new_state_dict,
"optimizer": optimizer.state_dict(),
"step": current_step,
"loss": model_loss,
"date": datetime.date.today().strftime("%B %d, %Y"),
}
best_loss = model_loss
bestmodel_path = 'best_model.pth.tar'
bestmodel_path = "best_model.pth.tar"
bestmodel_path = os.path.join(out_path, bestmodel_path)
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
model_loss, bestmodel_path))
print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
torch.save(state, bestmodel_path)
return best_loss
def check_config_speaker_encoder(c):
"""Check the config.json file of the speaker encoder"""
check_argument('run_name', c, restricted=True, val_type=str)
check_argument('run_description', c, val_type=str)
check_argument("run_name", c, restricted=True, val_type=str)
check_argument("run_description", c, val_type=str)
# audio processing parameters
check_argument('audio', c, restricted=True, val_type=dict)
check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
check_argument("audio", c, restricted=True, val_type=dict)
check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
check_argument(
"frame_length_ms",
c["audio"],
restricted=True,
val_type=float,
min_val=10,
max_val=1000,
alternative="win_length",
)
check_argument(
"frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
)
check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)
# training parameters
check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
check_argument('grad_clip', c, restricted=True, val_type=float)
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
check_argument('lr_decay', c, restricted=True, val_type=bool)
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
check_argument('num_loader_workers', c, restricted=True, val_type=int)
check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str)
check_argument("grad_clip", c, restricted=True, val_type=float)
check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
check_argument("lr", c, restricted=True, val_type=float, min_val=0)
check_argument("lr_decay", c, restricted=True, val_type=bool)
check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
check_argument("num_speakers_in_batch", c, restricted=True, val_type=int)
check_argument("num_loader_workers", c, restricted=True, val_type=int)
check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
# checkpoint and output parameters
check_argument('steps_plot_stats', c, restricted=True, val_type=int)
check_argument('checkpoint', c, restricted=True, val_type=bool)
check_argument('save_step', c, restricted=True, val_type=int)
check_argument('print_step', c, restricted=True, val_type=int)
check_argument('output_path', c, restricted=True, val_type=str)
check_argument("steps_plot_stats", c, restricted=True, val_type=int)
check_argument("checkpoint", c, restricted=True, val_type=bool)
check_argument("save_step", c, restricted=True, val_type=int)
check_argument("print_step", c, restricted=True, val_type=int)
check_argument("output_path", c, restricted=True, val_type=str)
# model parameters
check_argument('model', c, restricted=True, val_type=dict)
check_argument('input_dim', c['model'], restricted=True, val_type=int)
check_argument('proj_dim', c['model'], restricted=True, val_type=int)
check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
check_argument("model", c, restricted=True, val_type=dict)
check_argument("input_dim", c["model"], restricted=True, val_type=int)
check_argument("proj_dim", c["model"], restricted=True, val_type=int)
check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)
# in-memory storage parameters
check_argument('storage', c, restricted=True, val_type=dict)
check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
check_argument("storage", c, restricted=True, val_type=dict)
check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100)
check_argument("additive_noise", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
# datasets - checking only the first entry
check_argument('datasets', c, restricted=True, val_type=list)
for dataset_entry in c['datasets']:
check_argument('name', dataset_entry, restricted=True, val_type=str)
check_argument('path', dataset_entry, restricted=True, val_type=str)
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
check_argument("datasets", c, restricted=True, val_type=list)
for dataset_entry in c["datasets"]:
check_argument("name", dataset_entry, restricted=True, val_type=str)
check_argument("path", dataset_entry, restricted=True, val_type=str)
check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)
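For illustration, the helpers above behave roughly as follows; the learning-rate example mirrors one of the checks listed:

assert to_camel("speaker_encoder") == "SpeakerEncoder"  # capitalize, then fold "_x" into "X"
# check_argument validates a single config key, e.g. a non-negative float learning rate:
# check_argument("lr", c, restricted=True, val_type=float, min_val=0)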

View File

@ -17,7 +17,7 @@
# Only support eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
''' voxceleb 1 & 2 '''
""" voxceleb 1 & 2 """
import os
import sys
@ -32,40 +32,38 @@ import soundfile as sf
gfile = tf.compat.v1.gfile
SUBSETS = {
"vox1_dev_wav":
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
"vox1_test_wav":
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
"vox2_dev_aac":
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
"vox2_test_aac":
["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
"vox1_dev_wav": [
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad",
],
"vox1_test_wav": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
"vox2_dev_aac": [
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
"http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah",
],
"vox2_test_aac": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"],
}
MD5SUM = {
"vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
"vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
"vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
"vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
"vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312",
}
USER = {
"user": "",
"password": ""
}
USER = {"user": "", "password": ""}
speaker_id_dict = {}
def download_and_extract(directory, subset, urls):
"""Download and extract the given split of dataset.
@ -83,31 +81,30 @@ def download_and_extract(directory, subset, urls):
if os.path.exists(zip_filepath):
continue
logging.info("Downloading %s to %s" % (url, zip_filepath))
subprocess.call('wget %s --user %s --password %s -O %s' %
(url, USER["user"], USER["password"], zip_filepath), shell=True)
subprocess.call(
"wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
shell=True,
)
statinfo = os.stat(zip_filepath)
logging.info(
"Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
)
logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))
# concatenate all parts into zip files
if ".zip" not in zip_filepath:
zip_filepath = "_".join(zip_filepath.split("_")[:-1])
subprocess.call('cat %s* > %s.zip' %
(zip_filepath, zip_filepath), shell=True)
subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True)
zip_filepath += ".zip"
extract_path = zip_filepath.strip(".zip")
# check zip file md5sum
md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest()
md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
if md5 != MD5SUM[subset]:
raise ValueError("md5sum of %s mismatch" % zip_filepath)
with zipfile.ZipFile(zip_filepath, "r") as zfile:
zfile.extractall(directory)
extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True)
finally:
# gfile.Remove(zip_filepath)
pass
@ -148,8 +145,7 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
return True
def convert_audio_and_make_label(input_dir, subset,
output_dir, output_file):
def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
"""Optionally convert AAC to WAV and make speaker labels.
Args:
input_dir: the directory which holds the input dataset.
@ -167,7 +163,7 @@ def convert_audio_and_make_label(input_dir, subset,
for filename in filenames:
name, ext = os.path.splitext(filename)
if ext.lower() == ".wav":
_, ext2 = (os.path.splitext(name))
_, ext2 = os.path.splitext(name)
if ext2:
continue
wav_file = os.path.join(root, filename)
@ -186,15 +182,12 @@ def convert_audio_and_make_label(input_dir, subset,
speaker_id_dict[speaker_name] = num
# wav_filesize = os.path.getsize(wav_file)
wav_length = len(sf.read(wav_file)[0])
files.append(
(os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
)
files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name))
# Write to CSV file which contains four columns:
# "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
csv_file_path = os.path.join(output_dir, output_file)
df = pandas.DataFrame(
data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
df.to_csv(csv_file_path, index=False, sep="\t")
logging.info("Successfully generated csv file {}".format(csv_file_path))
@ -205,19 +198,14 @@ def processor(directory, subset, force_process):
if subset not in urls:
raise ValueError(subset, "is not in voxceleb")
subset_csv = os.path.join(directory, subset + '.csv')
subset_csv = os.path.join(directory, subset + ".csv")
if not force_process and os.path.exists(subset_csv):
return subset_csv
logging.info("Downloading and process the voxceleb in %s", directory)
logging.info("Preparing subset %s", subset)
download_and_extract(directory, subset, urls[subset])
convert_audio_and_make_label(
directory,
subset,
directory,
subset + ".csv"
)
convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
logging.info("Finished downloading and processing")
return subset_csv
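A minimal driver sketch for the preparation pipeline above; the directory is a placeholder, and the USER dict must hold valid VoxCeleb credentials before the wget calls can succeed:

csv_path = processor("/data/VoxCeleb1", "vox1_dev_wav", force_process=False)
# the resulting CSV is tab-separated with columns:
# wav_filename, wav_length_ms, speaker_id, speaker_name
print(csv_path)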

View File

@ -7,31 +7,31 @@ import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import (prepare_data, prepare_stop_target,
prepare_tensor)
from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence,
text_to_sequence)
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence
class MyDataset(Dataset):
def __init__(self,
outputs_per_step,
text_cleaner,
compute_linear_spec,
ap,
meta_data,
tp=None,
add_blank=False,
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
use_phonemes=True,
phoneme_cache_path=None,
phoneme_language="en-us",
enable_eos_bos=False,
speaker_mapping=None,
use_noise_augment=False,
verbose=False):
def __init__(
self,
outputs_per_step,
text_cleaner,
compute_linear_spec,
ap,
meta_data,
tp=None,
add_blank=False,
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
use_phonemes=True,
phoneme_cache_path=None,
phoneme_language="en-us",
enable_eos_bos=False,
speaker_mapping=None,
use_noise_augment=False,
verbose=False,
):
"""
Args:
outputs_per_step (int): number of time frames predicted per step.
@ -53,7 +53,7 @@ class MyDataset(Dataset):
use_noise_augment (bool): enable adding random noise to wav for augmentation.
verbose (bool): print diagnostic information.
"""
super(MyDataset, self).__init__()
super().__init__()
self.batch_group_size = batch_group_size
self.items = meta_data
self.outputs_per_step = outputs_per_step
@ -88,45 +88,42 @@ class MyDataset(Dataset):
@staticmethod
def load_np(filename):
data = np.load(filename).astype('float32')
data = np.load(filename).astype("float32")
return data
@staticmethod
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners,
language, tp, add_blank):
def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
"""generate a phoneme sequence from text.
Since the usage is for subsequent caching, we never add bos and
eos chars here. Instead we add those dynamically later, based on the
config option."""
phonemes = phoneme_to_sequence(text, [cleaners],
language=language,
enable_eos_bos=False,
tp=tp,
add_blank=add_blank)
phonemes = phoneme_to_sequence(
text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
)
phonemes = np.asarray(phonemes, dtype=np.int32)
np.save(cache_path, phonemes)
return phonemes
@staticmethod
def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path,
enable_eos_bos, cleaners, language,
tp, add_blank):
def _load_or_generate_phoneme_sequence(
wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
):
file_name = os.path.splitext(os.path.basename(wav_file))[0]
# different names for normal phonemes and with blank chars.
file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy'
cache_path = os.path.join(phoneme_cache_path,
file_name + file_name_ext)
file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
try:
phonemes = np.load(cache_path)
except FileNotFoundError:
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank)
text, cache_path, cleaners, language, tp, add_blank
)
except (ValueError, IOError):
print(" [!] failed loading phonemes for {}. "
"Recomputing.".format(wav_file))
print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
phonemes = MyDataset._generate_and_cache_phoneme_sequence(
text, cache_path, cleaners, language, tp, add_blank)
text, cache_path, cleaners, language, tp, add_blank
)
if enable_eos_bos:
phonemes = pad_with_eos_bos(phonemes, tp=tp)
phonemes = np.asarray(phonemes, dtype=np.int32)
@ -150,15 +147,20 @@ class MyDataset(Dataset):
if not self.input_seq_computed:
if self.use_phonemes:
text = self._load_or_generate_phoneme_sequence(
wav_file, text, self.phoneme_cache_path,
self.enable_eos_bos, self.cleaners, self.phoneme_language,
self.tp, self.add_blank)
wav_file,
text,
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.tp,
self.add_blank,
)
else:
text = np.asarray(text_to_sequence(text, [self.cleaners],
tp=self.tp,
add_blank=self.add_blank),
dtype=np.int32)
text = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
)
assert text.size > 0, self.items[idx][1]
assert wav.size > 0, self.items[idx][1]
@ -173,12 +175,12 @@ class MyDataset(Dataset):
return self.load_data(100)
sample = {
'text': text,
'wav': wav,
'attn': attn,
'item_idx': self.items[idx][1],
'speaker_name': speaker_name,
'wav_file_name': os.path.basename(wav_file)
"text": text,
"wav": wav,
"attn": attn,
"item_idx": self.items[idx][1],
"speaker_name": speaker_name,
"wav_file_name": os.path.basename(wav_file),
}
return sample
@ -187,8 +189,7 @@ class MyDataset(Dataset):
item = args[0]
func_args = args[1]
text, wav_file, *_ = item
phonemes = MyDataset._load_or_generate_phoneme_sequence(
wav_file, text, *func_args)
phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
return phonemes
def compute_input_seq(self, num_workers=0):
@ -199,17 +200,19 @@ class MyDataset(Dataset):
print(" | > Computing input sequences ...")
for idx, item in enumerate(tqdm.tqdm(self.items)):
text, *_ = item
sequence = np.asarray(text_to_sequence(
text, [self.cleaners],
tp=self.tp,
add_blank=self.add_blank),
dtype=np.int32)
sequence = np.asarray(
text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
)
self.items[idx][0] = sequence
else:
func_args = [
self.phoneme_cache_path, self.enable_eos_bos, self.cleaners,
self.phoneme_language, self.tp, self.add_blank
self.phoneme_cache_path,
self.enable_eos_bos,
self.cleaners,
self.phoneme_language,
self.tp,
self.add_blank,
]
if self.verbose:
print(" | > Computing phonemes ...")
@ -220,10 +223,11 @@ class MyDataset(Dataset):
else:
with Pool(num_workers) as p:
phonemes = list(
tqdm.tqdm(p.imap(MyDataset._phoneme_worker,
[[item, func_args]
for item in self.items]),
total=len(self.items)))
tqdm.tqdm(
p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
total=len(self.items),
)
)
for idx, p in enumerate(phonemes):
self.items[idx][0] = p
@ -255,8 +259,10 @@ class MyDataset(Dataset):
print(" | > Min length sequence: {}".format(np.min(lengths)))
print(" | > Avg length sequence: {}".format(np.mean(lengths)))
print(
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}"
.format(self.max_seq_len, self.min_seq_len, len(ignored)))
" | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
self.max_seq_len, self.min_seq_len, len(ignored)
)
)
print(" | > Batch group size: {}.".format(self.batch_group_size))
def __len__(self):
@ -267,11 +273,11 @@ class MyDataset(Dataset):
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:
1. Sort batch instances by text-length
2. Convert Audio signal to Spectrograms.
3. PAD sequences wrt r.
4. Load to Torch.
Perform preprocessing and create a final data batch:
1. Sort batch instances by text-length
2. Convert Audio signal to Spectrograms.
3. PAD sequences wrt r.
4. Load to Torch.
"""
# Puts each data field into a tensor with outer dimension batch size
@ -280,44 +286,29 @@ class MyDataset(Dataset):
text_lenghts = np.array([len(d["text"]) for d in batch])
# sort items with text input length for RNN efficiency
text_lenghts, ids_sorted_decreasing = torch.sort(
torch.LongTensor(text_lenghts), dim=0, descending=True)
text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)
wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing]
item_idxs = [
batch[idx]['item_idx'] for idx in ids_sorted_decreasing
]
text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
text = [batch[idx]["text"] for idx in ids_sorted_decreasing]
speaker_name = [
batch[idx]['speaker_name'] for idx in ids_sorted_decreasing
]
speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
# get speaker embeddings
if self.speaker_mapping is not None:
wav_files_names = [
batch[idx]['wav_file_name']
for idx in ids_sorted_decreasing
]
speaker_embedding = [
self.speaker_mapping[w]['embedding']
for w in wav_files_names
]
wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
else:
speaker_embedding = None
# compute features
mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]
mel_lengths = [m.shape[1] for m in mel]
# compute 'stop token' targets
stop_targets = [
np.array([0.] * (mel_len - 1) + [1.])
for mel_len in mel_lengths
]
stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]
# PAD stop targets
stop_targets = prepare_stop_target(stop_targets,
self.outputs_per_step)
stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
# PAD sequences with longest instance in the batch
text = prepare_data(text).astype(np.int32)
@ -340,9 +331,7 @@ class MyDataset(Dataset):
# compute linear spectrogram
if self.compute_linear_spec:
linear = [
self.ap.spectrogram(w).astype('float32') for w in wav
]
linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
linear = prepare_tensor(linear, self.outputs_per_step)
linear = linear.transpose(0, 2, 1)
assert mel.shape[1] == linear.shape[1]
@ -351,8 +340,8 @@ class MyDataset(Dataset):
linear = None
# collate attention alignments
if batch[0]['attn'] is not None:
attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
if batch[0]["attn"] is not None:
attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
for idx, attn in enumerate(attns):
pad2 = mel.shape[1] - attn.shape[1]
pad1 = text.shape[1] - attn.shape[0]
@ -362,8 +351,24 @@ class MyDataset(Dataset):
attns = torch.FloatTensor(attns).unsqueeze(1)
else:
attns = None
return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
stop_targets, item_idxs, speaker_embedding, attns
return (
text,
text_lenghts,
speaker_name,
linear,
mel,
mel_lengths,
stop_targets,
item_idxs,
speaker_embedding,
attns,
)
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
found {}".format(type(batch[0]))))
raise TypeError(
(
"batch must contain tensors, numbers, dicts or lists;\
found {}".format(
type(batch[0])
)
)
)
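A sketch of wiring the dataset above into a DataLoader; ap and train_items stand in for an AudioProcessor instance and the output of load_meta_data, and the cleaner name is an assumption:

from torch.utils.data import DataLoader

dataset = MyDataset(
    outputs_per_step=1,
    text_cleaner="phoneme_cleaners",  # assumed cleaner name
    compute_linear_spec=False,
    ap=ap,
    meta_data=train_items,
    use_phonemes=False,
)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=dataset.collate_fn)
batch = next(iter(loader))  # ten-element tuple, in the order returned by collate_fn above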

View File

@ -13,14 +13,15 @@ from TTS.tts.utils.generic_utils import split_dataset
# UTILITIES
####################
def load_meta_data(datasets, eval_split=True):
meta_data_train_all = []
meta_data_eval_all = [] if eval_split else None
for dataset in datasets:
name = dataset['name']
root_path = dataset['path']
meta_file_train = dataset['meta_file_train']
meta_file_val = dataset['meta_file_val']
name = dataset["name"]
root_path = dataset["path"]
meta_file_train = dataset["meta_file_train"]
meta_file_val = dataset["meta_file_val"]
# setup the right data processor
preprocessor = get_preprocessor_by_name(name)
# load train set
@ -35,8 +36,8 @@ def load_meta_data(datasets, eval_split=True):
meta_data_eval_all += meta_data_eval
meta_data_train_all += meta_data_train
# load attention masks for duration predictor training
if 'meta_file_attn_mask' in dataset and dataset['meta_file_attn_mask'] is not None:
meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
if "meta_file_attn_mask" in dataset and dataset["meta_file_attn_mask"] is not None:
meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
for idx, ins in enumerate(meta_data_train_all):
attn_file = meta_data[ins[1]].strip()
meta_data_train_all[idx].append(attn_file)
@ -49,12 +50,12 @@ def load_meta_data(datasets, eval_split=True):
def load_attention_mask_meta_data(metafile_path):
"""Load meta data file created by compute_attention_masks.py"""
with open(metafile_path, 'r') as f:
with open(metafile_path, "r") as f:
lines = f.readlines()
meta_data = []
for line in lines:
wav_file, attn_file = line.split('|')
wav_file, attn_file = line.split("|")
meta_data.append([wav_file, attn_file])
return meta_data
@ -69,6 +70,7 @@ def get_preprocessor_by_name(name):
# DATASETS
########################
def tweb(root_path, meta_file):
"""Normalize TWEB dataset.
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
@ -76,10 +78,10 @@ def tweb(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "tweb"
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
cols = line.split('\t')
wav_file = os.path.join(root_path, cols[0] + '.wav')
cols = line.split("\t")
wav_file = os.path.join(root_path, cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
return items
@ -90,9 +92,9 @@ def mozilla(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
cols = line.split('|')
cols = line.split("|")
wav_file = cols[1].strip()
text = cols[0].strip()
wav_file = os.path.join(root_path, "wavs", wav_file)
@ -105,9 +107,9 @@ def mozilla_de(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "mozilla"
with open(txt_file, 'r', encoding="ISO 8859-1") as ttf:
with open(txt_file, "r", encoding="ISO 8859-1") as ttf:
for line in ttf:
cols = line.strip().split('|')
cols = line.strip().split("|")
wav_file = cols[0].strip()
text = cols[1].strip()
folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"
@ -118,8 +120,7 @@ def mozilla_de(root_path, meta_file):
def mailabs(root_path, meta_files=None):
"""Normalizes M-AI-Labs meta data files to TTS format"""
speaker_regex = re.compile(
"by_book/(male|female)/(?P<speaker_name>[^/]+)/")
speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
if meta_files is None:
csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
else:
@ -135,21 +136,18 @@ def mailabs(root_path, meta_files=None):
continue
speaker_name = speaker_name_match.group("speaker_name")
print(" | > {}".format(csv_file))
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
cols = line.split('|')
cols = line.split("|")
if meta_files is None:
wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav')
wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
else:
wav_file = os.path.join(root_path,
folder.replace("metadata.csv", ""),
'wavs', cols[0] + '.wav')
wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
if os.path.isfile(wav_file):
text = cols[1].strip()
items.append([text, wav_file, speaker_name])
else:
raise RuntimeError("> File %s does not exist!" %
(wav_file))
raise RuntimeError("> File %s does not exist!" % (wav_file))
return items
@ -159,10 +157,10 @@ def ljspeech(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, 'r', encoding="utf-8") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
return items
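Each formatter in this file returns the same record shape, one [text, wav_path, speaker_name] triple per utterance; for example, with a placeholder dataset root:

items = ljspeech("/data/LJSpeech-1.1", "metadata.csv")
text, wav_path, speaker = items[0]
# e.g. wav_path -> "/data/LJSpeech-1.1/wavs/LJ001-0001.wav", speaker -> "ljspeech"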
@ -171,15 +169,15 @@ def ljspeech(root_path, meta_file):
def sam_accenture(root_path, meta_file):
"""Normalizes the sam-accenture meta data file to TTS format
https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
xml_file = os.path.join(root_path, 'voice_over_recordings', meta_file)
xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
xml_root = ET.parse(xml_file).getroot()
items = []
speaker_name = "sam_accenture"
for item in xml_root.findall('./fileid'):
for item in xml_root.findall("./fileid"):
text = item.text
wav_file = os.path.join(root_path, 'vo_voice_quality_transformation', item.get('id')+'.wav')
wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
if not os.path.exists(wav_file):
print(f' [!] {wav_file} in metafile does not exist. Skipping...')
print(f" [!] {wav_file} in metafile does not exist. Skipping...")
continue
items.append([text, wav_file, speaker_name])
return items
@ -191,10 +189,10 @@ def ruslan(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, 'r', encoding="utf-8") as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, 'RUSLAN', cols[0] + '.wav')
cols = line.split("|")
wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, speaker_name])
return items
@ -205,9 +203,9 @@ def css10(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "ljspeech"
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
cols = line.split('|')
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[1]
items.append([text, wav_file, speaker_name])
@ -219,10 +217,10 @@ def nancy(root_path, meta_file):
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "nancy"
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
utt_id = line.split()[1]
text = line[line.find('"') + 1:line.rfind('"') - 1]
text = line[line.find('"') + 1 : line.rfind('"') - 1]
wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
items.append([text, wav_file, speaker_name])
return items
@ -232,7 +230,7 @@ def common_voice(root_path, meta_file):
"""Normalize the common voice meta data file to TTS format."""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
if line.startswith("client_id"):
continue
@ -240,7 +238,7 @@ def common_voice(root_path, meta_file):
text = cols[2]
speaker_name = cols[0]
wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
items.append([text, wav_file, 'MCV_' + speaker_name])
items.append([text, wav_file, "MCV_" + speaker_name])
return items
@ -250,19 +248,18 @@ def libri_tts(root_path, meta_files=None):
if meta_files is None:
meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
for meta_file in meta_files:
_meta_file = os.path.basename(meta_file).split('.')[0]
speaker_name = _meta_file.split('_')[0]
chapter_id = _meta_file.split('_')[1]
_meta_file = os.path.basename(meta_file).split(".")[0]
speaker_name = _meta_file.split("_")[0]
chapter_id = _meta_file.split("_")[1]
_root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
with open(meta_file, 'r') as ttf:
with open(meta_file, "r") as ttf:
for line in ttf:
cols = line.split('\t')
wav_file = os.path.join(_root_path, cols[0] + '.wav')
cols = line.split("\t")
wav_file = os.path.join(_root_path, cols[0] + ".wav")
text = cols[1]
items.append([text, wav_file, 'LTTS_' + speaker_name])
items.append([text, wav_file, "LTTS_" + speaker_name])
for item in items:
assert os.path.exists(
item[1]), f" [!] wav files don't exist - {item[1]}"
assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
return items
@ -271,11 +268,10 @@ def custom_turkish(root_path, meta_file):
items = []
speaker_name = "turkish-female"
skipped_files = []
with open(txt_file, 'r', encoding='utf-8') as ttf:
with open(txt_file, "r", encoding="utf-8") as ttf:
for line in ttf:
cols = line.split('|')
wav_file = os.path.join(root_path, 'wavs',
cols[0].strip() + '.wav')
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav")
if not os.path.exists(wav_file):
skipped_files.append(wav_file)
continue
@ -287,14 +283,14 @@ def custom_turkish(root_path, meta_file):
# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
'''BRSpeech 3.0 beta'''
"""BRSpeech 3.0 beta"""
txt_file = os.path.join(root_path, meta_file)
items = []
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
if line.startswith("wav_filename"):
continue
cols = line.split('|')
cols = line.split("|")
wav_file = os.path.join(root_path, cols[0])
text = cols[2]
speaker_name = cols[3]
@ -302,45 +298,41 @@ def brspeech(root_path, meta_file):
return items
def vctk(root_path, meta_files=None, wavs_path='wav48'):
def vctk(root_path, meta_files=None, wavs_path="wav48"):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
test_speakers = meta_files
items = []
meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for meta_file in meta_files:
_, speaker_id, txt_file = os.path.relpath(meta_file,
root_path).split(os.sep)
file_id = txt_file.split('.')[0]
if isinstance(test_speakers,
list): # if is list ignore this speakers ids
_, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(test_speakers, list): # if is list ignore this speakers ids
if speaker_id in test_speakers:
continue
with open(meta_file) as file_text:
text = file_text.readlines()[0]
wav_file = os.path.join(root_path, wavs_path, speaker_id,
file_id + '.wav')
items.append([text, wav_file, 'VCTK_' + speaker_id])
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([text, wav_file, "VCTK_" + speaker_id])
return items
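Editor's note: a hedged usage sketch of the speaker filtering above (corpus path and speaker IDs below are hypothetical). When meta_files is a list it is read as speaker IDs to skip, so held-out speakers can be removed from the returned items.

    # hypothetical corpus path and speaker IDs; usage sketch only
    train_items = vctk("/data/VCTK-Corpus", meta_files=["p225", "p226"])  # skips speakers p225 and p226
    all_items = vctk("/data/VCTK-Corpus")  # meta_files=None keeps every speaker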
def vctk_slim(root_path, meta_files=None, wavs_path='wav48'):
def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
"""homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
items = []
txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
for text_file in txt_files:
_, speaker_id, txt_file = os.path.relpath(text_file,
root_path).split(os.sep)
file_id = txt_file.split('.')[0]
_, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
file_id = txt_file.split(".")[0]
if isinstance(meta_files, list): # if is list ignore this speakers ids
if speaker_id in meta_files:
continue
wav_file = os.path.join(root_path, wavs_path, speaker_id,
file_id + '.wav')
items.append([None, wav_file, 'VCTK_' + speaker_id])
wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
items.append([None, wav_file, "VCTK_" + speaker_id])
return items
# ======================================== VOX CELEB ===========================================
def voxceleb2(root_path, meta_file=None):
"""
@ -365,31 +357,33 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):
# if the meta file does not exist, crawl recursively for 'wav' files
if meta_file is not None:
with open(str(meta_file), 'r') as f:
return [x.strip().split('|') for x in f.readlines()]
with open(str(meta_file), "r") as f:
return [x.strip().split("|") for x in f.readlines()]
elif not cache_to.exists():
cnt = 0
meta_data = []
wav_files = voxceleb_path.rglob("**/*.wav")
for path in tqdm(wav_files, desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.",
total=expected_count):
for path in tqdm(
wav_files,
desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.",
total=expected_count,
):
speaker_id = str(Path(path).parent.parent.stem)
assert speaker_id.startswith('id')
assert speaker_id.startswith("id")
text = None # VoxCel does not provide transcriptions, and they are not needed for training the SE
meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
cnt += 1
with open(str(cache_to), 'w') as f:
with open(str(cache_to), "w") as f:
f.write("".join(meta_data))
if cnt < expected_count:
raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")
with open(str(cache_to), 'r') as f:
return [x.strip().split('|') for x in f.readlines()]
with open(str(cache_to), "r") as f:
return [x.strip().split("|") for x in f.readlines()]
def baker(root_path: str, meta_file: str) -> List[List[str]]:
def baker(root_path: str, meta_file: str) -> List[List[str]]:
"""Normalizes the Baker meta data file to TTS format
Args:
@ -401,9 +395,9 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
txt_file = os.path.join(root_path, meta_file)
items = []
speaker_name = "baker"
with open(txt_file, 'r') as ttf:
with open(txt_file, "r") as ttf:
for line in ttf:
wav_name, text = line.rstrip('\n').split("|")
wav_name, text = line.rstrip("\n").split("|")
wav_path = os.path.join(root_path, "clips_22", wav_name)
items.append([text, wav_path, speaker_name])
return items
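Editor's note: every formatter above follows the same contract, parse a metafile and return a list of [text, wav_path, speaker_name] entries. A minimal self-contained sketch of that contract, assuming a toy pipe-delimited metafile rather than any of the real datasets:

    import os
    from typing import List

    def toy_formatter(root_path: str, meta_file: str) -> List[List[str]]:
        """Parse 'wav_id|transcript' lines into [text, wav_path, speaker_name] items (sketch only)."""
        items = []
        with open(os.path.join(root_path, meta_file), "r", encoding="utf-8") as f:
            for line in f:
                wav_id, text = line.rstrip("\n").split("|", 1)
                wav_path = os.path.join(root_path, "wavs", wav_id + ".wav")
                items.append([text, wav_path, "toy_speaker"])
        return items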

View File

@ -5,6 +5,7 @@ class MDNBlock(nn.Module):
"""Mixture of Density Network implementation
https://arxiv.org/pdf/2003.01950.pdf
"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.out_channels = out_channels
@ -24,6 +25,6 @@ class MDNBlock(nn.Module):
mu_sigma = self.conv2(o)
# TODO: check this sigmoid
# mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :])
mu = mu_sigma[:, :self.out_channels//2, :]
log_sigma = mu_sigma[:, self.out_channels//2:, :]
mu = mu_sigma[:, : self.out_channels // 2, :]
log_sigma = mu_sigma[:, self.out_channels // 2 :, :]
return mu, log_sigma
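Editor's note: a toy shape check for the slicing above (sizes assumed). The 1x1 conv emits 2 * out_channels, and the two channel halves are read as mu and log_sigma.

    import torch

    B, C_out, T = 2, 8, 5
    mu_sigma = torch.randn(B, 2 * C_out, T)  # stand-in for the conv2 output in MDNBlock
    mu = mu_sigma[:, :C_out, :]              # first half of the channels
    log_sigma = mu_sigma[:, C_out:, :]       # second half
    assert mu.shape == log_sigma.shape == (B, C_out, T)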

View File

@ -31,15 +31,16 @@ class WaveNetDecoder(nn.Module):
hidden_channels (int): number of hidden channels for prenet and postnet.
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
super().__init__()
# prenet
self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1)
# wavenet layers
self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params)
# postnet
self.postnet = [
torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1),
torch.nn.ReLU(),
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
torch.nn.ReLU(),
@ -77,12 +78,12 @@ class RelativePositionTransformerDecoder(nn.Module):
hidden_channels (int): number of hidden channels including Transformer layers.
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
self.rel_pos_transformer = RelativePositionTransformer(
in_channels, out_channels, hidden_channels, **params)
self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
o = self.prenet(x) * x_mask
@ -107,6 +108,7 @@ class FFTransformerDecoder(nn.Module):
hidden_channels (int): number of hidden channels including Transformer layers.
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, params):
super().__init__()
@ -117,7 +119,7 @@ class FFTransformerDecoder(nn.Module):
# TODO: handle multi-speaker
x_mask = 1 if x_mask is None else x_mask
o = self.transformer_block(x) * x_mask
o = self.postnet(o)* x_mask
o = self.postnet(o) * x_mask
return o
@ -141,19 +143,15 @@ class ResidualConv1dBNDecoder(nn.Module):
hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.res_conv_block = ResidualConv1dBNBlock(in_channels,
hidden_channels,
hidden_channels, **params)
self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params)
self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
self.postnet = nn.Sequential(
Conv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels,
params['kernel_size'],
1,
num_conv_blocks=2),
Conv1dBNBlock(
hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2
),
nn.Conv1d(hidden_channels, out_channels, 1),
)
@ -178,17 +176,18 @@ class Decoder(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
out_channels,
in_hidden_channels,
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
c_in_channels=0):
self,
out_channels,
in_hidden_channels,
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17,
},
c_in_channels=0,
):
super().__init__()
if decoder_type.lower() == "relative_position_transformer":
@ -196,23 +195,27 @@ class Decoder(nn.Module):
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type.lower() == 'residual_conv_bn':
params=decoder_params,
)
elif decoder_type.lower() == "residual_conv_bn":
self.decoder = ResidualConv1dBNDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type.lower() == 'wavenet':
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels,
params=decoder_params)
elif decoder_type.lower() == 'fftransformer':
params=decoder_params,
)
elif decoder_type.lower() == "wavenet":
self.decoder = WaveNetDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels,
params=decoder_params,
)
elif decoder_type.lower() == "fftransformer":
self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
raise ValueError(f"[!] Unknown decoder type - {decoder_type}")
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
"""

View File

@ -16,16 +16,19 @@ class DurationPredictor(nn.Module):
Args:
hidden_channels (int): number of channels in the inner layers.
"""
def __init__(self, hidden_channels):
super().__init__()
self.layers = nn.ModuleList([
Conv1dBN(hidden_channels, hidden_channels, 4, 1),
Conv1dBN(hidden_channels, hidden_channels, 3, 1),
Conv1dBN(hidden_channels, hidden_channels, 1, 1),
nn.Conv1d(hidden_channels, 1, 1)
])
self.layers = nn.ModuleList(
[
Conv1dBN(hidden_channels, hidden_channels, 4, 1),
Conv1dBN(hidden_channels, hidden_channels, 3, 1),
Conv1dBN(hidden_channels, hidden_channels, 1, 1),
nn.Conv1d(hidden_channels, 1, 1),
]
)
def forward(self, x, x_mask):
"""

View File

@ -1,7 +1,7 @@
from torch import nn
from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock
@ -16,17 +16,19 @@ class RelativePositionTransformerEncoder(nn.Module):
hidden_channels (int): number of hidden channels
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = ResidualConv1dBNBlock(in_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_res_blocks=3,
num_conv_blocks=1,
dilations=[1, 1, 1])
self.rel_pos_transformer = RelativePositionTransformer(
hidden_channels, out_channels, hidden_channels, **params)
self.prenet = ResidualConv1dBNBlock(
in_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_res_blocks=3,
num_conv_blocks=1,
dilations=[1, 1, 1],
)
self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
@ -47,20 +49,20 @@ class ResidualConv1dBNEncoder(nn.Module):
hidden_channels (int): number of hidden channels
params (dict): dictionary for residual convolutional blocks.
"""
def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1),
nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels, **params)
self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)
self.postnet = nn.Sequential(*[
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU(),
nn.BatchNorm1d(hidden_channels),
nn.Conv1d(hidden_channels, out_channels, 1)
])
self.postnet = nn.Sequential(
*[
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU(),
nn.BatchNorm1d(hidden_channels),
nn.Conv1d(hidden_channels, out_channels, 1),
]
)
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
if x_mask is None:
@ -115,18 +117,15 @@ class Encoder(nn.Module):
}
```
"""
def __init__(
self,
in_hidden_channels,
out_channels,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
c_in_channels=0):
self,
in_hidden_channels,
out_channels,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
c_in_channels=0,
):
super().__init__()
self.out_channels = out_channels
self.in_channels = in_hidden_channels
@ -137,21 +136,22 @@ class Encoder(nn.Module):
# init encoder
if encoder_type.lower() == "relative_position_transformer":
# text encoder
# pylint: disable=unexpected-keyword-arg
self.encoder = RelativePositionTransformerEncoder(
in_hidden_channels, out_channels, in_hidden_channels,
encoder_params) # pylint: disable=unexpected-keyword-arg
elif encoder_type.lower() == 'residual_conv_bn':
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
out_channels,
in_hidden_channels,
encoder_params)
elif encoder_type.lower() == 'fftransformer':
assert in_hidden_channels == out_channels, \
"[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params) # pylint: disable=unexpected-keyword-arg
in_hidden_channels, out_channels, in_hidden_channels, encoder_params
)
elif encoder_type.lower() == "residual_conv_bn":
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
elif encoder_type.lower() == "fftransformer":
assert (
in_hidden_channels == out_channels
), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
# pylint: disable=unexpected-keyword-arg
self.encoder = FFTransformerBlock(
in_hidden_channels, **encoder_params
)
else:
raise NotImplementedError(' [!] unknown encoder type.')
raise NotImplementedError(" [!] unknown encoder type.")
def forward(self, x, x_mask, g=None): # pylint: disable=unused-argument
"""

View File

@ -10,6 +10,7 @@ class GatedConvBlock(nn.Module):
kernel_size (int): convolution kernel size.
dropout_p (float): dropout rate.
"""
def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
super().__init__()
# class arguments
@ -20,21 +21,14 @@ class GatedConvBlock(nn.Module):
self.norm_layers = nn.ModuleList()
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.conv_layers += [
nn.Conv1d(in_out_channels,
2 * in_out_channels,
kernel_size,
padding=kernel_size // 2)
]
self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)]
self.norm_layers += [LayerNorm(2 * in_out_channels)]
def forward(self, x, x_mask):
o = x
res = x
for idx in range(self.num_layers):
o = nn.functional.dropout(o,
p=self.dropout_p,
training=self.training)
o = nn.functional.dropout(o, p=self.dropout_p, training=self.training)
o = self.conv_layers[idx](o * x_mask)
o = self.norm_layers[idx](o)
o = nn.functional.glu(o, dim=1)
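Editor's note: the gating relies on each conv emitting 2 * in_out_channels and F.glu halving them again, one half is the signal and the other a sigmoid gate. A toy shape check with assumed sizes:

    import torch
    from torch import nn

    conv_out = torch.randn(2, 2 * 16, 50)       # stand-in for a conv_layers output, in_out_channels=16
    gated = nn.functional.glu(conv_out, dim=1)  # a * sigmoid(b) over the two channel halves
    assert gated.shape == (2, 16, 50)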

View File

@ -22,24 +22,17 @@ class LayerNorm(nn.Module):
def forward(self, x):
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
x = x * self.gamma + self.beta
return x
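Editor's note: the channel-wise normalization above can be verified numerically; a minimal sketch with gamma=1, beta=0 and an assumed eps value:

    import torch

    x = torch.randn(2, 16, 40)  # (B, C, T)
    mean = torch.mean(x, 1, keepdim=True)
    variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
    y = (x - mean) * torch.rsqrt(variance + 1e-4)  # eps value assumed
    # every (batch, time) column is now ~zero-mean over the channel axis
    assert torch.allclose(y.mean(dim=1), torch.zeros(2, 40), atol=1e-5)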
class TemporalBatchNorm1d(nn.BatchNorm1d):
"""Normalize each channel separately over time and batch.
"""
def __init__(self,
channels,
affine=True,
track_running_stats=True,
momentum=0.1):
super().__init__(channels,
affine=affine,
track_running_stats=track_running_stats,
momentum=momentum)
"""Normalize each channel separately over time and batch."""
def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)
def forward(self, x):
return super().forward(x.transpose(2, 1)).transpose(2, 1)
@ -58,6 +51,7 @@ class ActNorm(nn.Module):
- inputs: (B, C, T)
- outputs: (B, C, T)
"""
def __init__(self, channels, ddi=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
self.channels = channels
@ -68,8 +62,7 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, reverse=False, **kwargs): # pylint: disable=unused-argument
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
@ -95,13 +88,11 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m**2)
v = m_sq - (m ** 2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(
dtype=self.logs.dtype)
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
self.logs.data.copy_(logs_init)

View File

@ -11,20 +11,20 @@ class PositionalEncoding(nn.Module):
channels (int): embedding size
dropout (float): dropout parameter
"""
def __init__(self, channels, dropout_p=0.0, max_len=5000):
super().__init__()
if channels % 2 != 0:
raise ValueError(
"Cannot use sin/cos positional encoding with "
"odd channels (got channels={:d})".format(channels))
"Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
)
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.pow(10000,
torch.arange(0, channels, 2).float() / channels)
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0).transpose(1, 2)
self.register_buffer('pe', pe)
self.register_buffer("pe", pe)
if dropout_p > 0:
self.dropout = nn.Dropout(p=dropout_p)
self.channels = channels
@ -43,14 +43,15 @@ class PositionalEncoding(nn.Module):
if self.pe.size(2) < x.size(2):
raise RuntimeError(
f"Sequence is {x.size(2)} but PositionalEncoding is"
f" limited to {self.pe.size(2)}. See max_len argument.")
f" limited to {self.pe.size(2)}. See max_len argument."
)
if mask is not None:
pos_enc = (self.pe[:, :, :x.size(2)] * mask)
pos_enc = self.pe[:, :, : x.size(2)] * mask
else:
pos_enc = self.pe[:, :, :x.size(2)]
pos_enc = self.pe[:, :, : x.size(2)]
x = x + pos_enc
else:
x = x + self.pe[:, :, first_idx:last_idx]
if hasattr(self, 'dropout'):
if hasattr(self, "dropout"):
x = self.dropout(x)
return x
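Editor's note: the buffer registered in __init__ above is the sin/cos table; a standalone re-creation with toy sizes (channels and max_len assumed):

    import torch

    channels, max_len = 8, 10
    position = torch.arange(0, max_len).unsqueeze(1)                              # (max_len, 1)
    div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)  # one term per sin/cos pair
    pe = torch.zeros(max_len, channels)
    pe[:, 0::2] = torch.sin(position.float() * div_term)
    pe[:, 1::2] = torch.cos(position.float() * div_term)
    pe = pe.unsqueeze(0).transpose(1, 2)  # (1, channels, max_len), ready to add to a (B, C, T) input
    assert pe.shape == (1, channels, max_len)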

View File

@ -3,9 +3,10 @@ from torch import nn
class ZeroTemporalPad(nn.Module):
"""Pad sequences to equal lentgh in the temporal dimension"""
def __init__(self, kernel_size, dilation):
super().__init__()
total_pad = (dilation * (kernel_size - 1))
total_pad = dilation * (kernel_size - 1)
begin = total_pad // 2
end = total_pad - begin
self.pad_layer = nn.ZeroPad2d((0, 0, begin, end))
@ -27,9 +28,10 @@ class Conv1dBN(nn.Module):
kernel_size (int): kernel size for convolutional filters.
dilation (int): dilation for convolution layers.
"""
def __init__(self, in_channels, out_channels, kernel_size, dilation):
super().__init__()
padding = (dilation * (kernel_size - 1))
padding = dilation * (kernel_size - 1)
pad_s = padding // 2
pad_e = padding - pad_s
self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
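Editor's note: worked padding arithmetic for the block above (example values assumed). The asymmetric begin/end split keeps the temporal length unchanged even for even kernel sizes.

    kernel_size, dilation, length = 4, 1, 100                             # assumed example values
    padding = dilation * (kernel_size - 1)                                # 3
    pad_s, pad_e = padding // 2, padding - padding // 2                   # 1 before, 2 after
    out_length = (length + pad_s + pad_e) - dilation * (kernel_size - 1)  # conv1d with no internal padding
    assert out_length == length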
@ -55,14 +57,17 @@ class Conv1dBNBlock(nn.Module):
dilation (int): dilation for convolution layers.
num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2.
"""
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
super().__init__()
self.conv_bn_blocks = []
for idx in range(num_conv_blocks):
layer = Conv1dBN(in_channels if idx == 0 else hidden_channels,
out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
kernel_size,
dilation)
layer = Conv1dBN(
in_channels if idx == 0 else hidden_channels,
out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
kernel_size,
dilation,
)
self.conv_bn_blocks.append(layer)
self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks)
@ -91,18 +96,23 @@ class ResidualConv1dBNBlock(nn.Module):
num_res_blocks (int, optional): number of residual blocks. Defaults to 13.
num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2.
"""
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2
):
super().__init__()
assert len(dilations) == num_res_blocks
self.res_blocks = nn.ModuleList()
for idx, dilation in enumerate(dilations):
block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == len(dilations) else hidden_channels,
hidden_channels,
kernel_size,
dilation,
num_conv_blocks)
block = Conv1dBNBlock(
in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == len(dilations) else hidden_channels,
hidden_channels,
kernel_size,
dilation,
num_conv_blocks,
)
self.res_blocks.append(block)
def forward(self, x, x_mask=None):

View File

@ -5,12 +5,8 @@ from torch import nn
class TimeDepthSeparableConv(nn.Module):
"""Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf
It shows competitive results with less computation and memory footprint."""
def __init__(self,
in_channels,
hid_channels,
out_channels,
kernel_size,
bias=True):
def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True):
super().__init__()
self.in_channels = in_channels
@ -62,28 +58,24 @@ class TimeDepthSeparableConv(nn.Module):
class TimeDepthSeparableConvBlock(nn.Module):
def __init__(self,
in_channels,
hid_channels,
out_channels,
num_layers,
kernel_size,
bias=True):
def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True):
super().__init__()
assert (kernel_size - 1) % 2 == 0
assert num_layers > 1
self.layers = nn.ModuleList()
layer = TimeDepthSeparableConv(
in_channels, hid_channels,
out_channels if num_layers == 1 else hid_channels, kernel_size,
bias)
in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias
)
self.layers.append(layer)
for idx in range(num_layers - 1):
layer = TimeDepthSeparableConv(
hid_channels, hid_channels, out_channels if
(idx + 1) == (num_layers - 1) else hid_channels, kernel_size,
bias)
hid_channels,
hid_channels,
out_channels if (idx + 1) == (num_layers - 1) else hid_channels,
kernel_size,
bias,
)
self.layers.append(layer)
def forward(self, x, mask):

View File

@ -4,16 +4,9 @@ import torch.nn.functional as F
class FFTransformer(nn.Module):
def __init__(self,
in_out_channels,
num_heads,
hidden_channels_ffn=1024,
kernel_size_fft=3,
dropout_p=0.1):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(in_out_channels,
num_heads,
dropout=dropout_p)
self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p)
padding = (kernel_size_fft - 1) // 2
self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding)
@ -27,11 +20,7 @@ class FFTransformer(nn.Module):
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""😦 ugly looking with all the transposing """
src = src.permute(2, 0, 1)
src2, enc_align = self.self_attn(src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
src = self.norm1(src + src2)
# T x B x D -> B x D x T
src = src.permute(1, 2, 0)
@ -45,15 +34,19 @@ class FFTransformer(nn.Module):
class FFTransformerBlock(nn.Module):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn,
num_layers, dropout_p):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p):
super().__init__()
self.fft_layers = nn.ModuleList([
FFTransformer(in_out_channels=in_out_channels,
num_heads=num_heads,
hidden_channels_ffn=hidden_channels_ffn,
dropout_p=dropout_p) for _ in range(num_layers)
])
self.fft_layers = nn.ModuleList(
[
FFTransformer(
in_out_channels=in_out_channels,
num_heads=num_heads,
hidden_channels_ffn=hidden_channels_ffn,
dropout_p=dropout_p,
)
for _ in range(num_layers)
]
)
def forward(self, x, mask=None, g=None): # pylint: disable=unused-argument
"""

View File

@ -32,15 +32,18 @@ class WN(torch.nn.Module):
dropout_p (float): dropout rate.
weight_norm (bool): enable/disable weight norm for convolution layers.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True,
):
super().__init__()
assert kernel_size % 2 == 1
assert hidden_channels % 2 == 0
@ -58,20 +61,16 @@ class WN(torch.nn.Module):
# init conditioning layer
if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels,
2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
name='weight')
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate**i
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)
if i < num_layers - 1:
@ -79,10 +78,8 @@ class WN(torch.nn.Module):
else:
res_skip_channels = hidden_channels
res_skip_layer = torch.nn.Conv1d(hidden_channels,
res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
name='weight')
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
# setup weight norm
if not weight_norm:
@ -99,15 +96,14 @@ class WN(torch.nn.Module):
x_in = self.dropout(x_in)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.num_layers - 1:
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
output = output + res_skip_acts[:, self.hidden_channels:, :]
x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
@ -140,28 +136,32 @@ class WNBlocks(nn.Module):
weight_norm (bool): enable/disable weight norm for convolution layers.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_blocks,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_blocks,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True,
):
super().__init__()
self.wn_blocks = nn.ModuleList()
for idx in range(num_blocks):
layer = WN(in_channels=in_channels if idx == 0 else hidden_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
weight_norm=weight_norm)
layer = WN(
in_channels=in_channels if idx == 0 else hidden_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
weight_norm=weight_norm,
)
self.wn_blocks.append(layer)
def forward(self, x, x_mask=None, g=None):

View File

@ -18,14 +18,12 @@ def squeeze(x, x_mask=None, num_sqz=2):
t = (t // num_sqz) * num_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1,
2).contiguous().view(b, c * num_sqz, t // num_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
if x_mask is not None:
x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz]
else:
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask
@ -34,20 +32,16 @@ def unsqueeze(x, x_mask=None, num_sqz=2):
Note:
each 's' is an n-dimensional vector.
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]] """
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]"""
b, c, t = x.size()
x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3,
1).contiguous().view(b, c // num_sqz,
t * num_sqz)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz)
if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1,
num_sqz).view(b, 1, t * num_sqz)
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz)
else:
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask
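Editor's note: a toy round trip through the two reshapes above (sizes assumed). squeeze folds pairs of time steps into the channel axis and unsqueeze restores the original tensor exactly.

    import torch

    b, c, t, num_sqz = 1, 2, 6, 2
    x = torch.arange(b * c * t, dtype=torch.float32).view(b, c, t)
    x_sqz = x.view(b, c, t // num_sqz, num_sqz).permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
    assert x_sqz.shape == (b, c * num_sqz, t // num_sqz)  # (1, 4, 3)
    x_back = x_sqz.view(b, num_sqz, c, t // num_sqz).permute(0, 2, 3, 1).contiguous().view(b, c, t)
    assert torch.equal(x_back, x)  # unsqueeze inverts squeeze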
@ -65,18 +59,21 @@ class Decoder(nn.Module):
dropout_p (float): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
c_in_channels=0):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
c_in_channels=0,
):
super().__init__()
self.in_channels = in_channels
@ -94,18 +91,19 @@ class Decoder(nn.Module):
self.flows = nn.ModuleList()
for _ in range(num_flow_blocks):
self.flows.append(ActNorm(channels=in_channels * num_squeeze))
self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits))
self.flows.append(
InvConvNear(channels=in_channels * num_squeeze,
num_splits=num_splits))
self.flows.append(
CouplingBlock(in_channels * num_squeeze,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale))
CouplingBlock(
in_channels * num_squeeze,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale,
)
)
def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:

View File

@ -14,6 +14,7 @@ class DurationPredictor(nn.Module):
kernel_size ([type]): [description]
dropout_p ([type]): [description]
"""
def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
super().__init__()
# class arguments
@ -23,15 +24,9 @@ class DurationPredictor(nn.Module):
self.dropout_p = dropout_p
# layers
self.drop = nn.Dropout(dropout_p)
self.conv_1 = nn.Conv1d(in_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(hidden_channels)
self.conv_2 = nn.Conv1d(hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(hidden_channels)
# output layer
self.proj = nn.Conv1d(hidden_channels, 1, 1)

View File

@ -69,17 +69,20 @@ class Encoder(nn.Module):
'num_layers': 9,
}
"""
def __init__(self,
num_chars,
out_channels,
hidden_channels,
hidden_channels_dp,
encoder_type,
encoder_params,
dropout_p_dp=0.1,
mean_only=False,
use_prenet=True,
c_in_channels=0):
def __init__(
self,
num_chars,
out_channels,
hidden_channels,
hidden_channels_dp,
encoder_type,
encoder_params,
dropout_p_dp=0.1,
mean_only=False,
use_prenet=True,
c_in_channels=0,
):
super().__init__()
# class arguments
self.num_chars = num_chars
@ -93,47 +96,33 @@ class Encoder(nn.Module):
self.encoder_type = encoder_type
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
# init encoder module
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
self.encoder = RelativePositionTransformer(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
elif encoder_type.lower() == 'gated_conv':
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == 'residual_conv_bn':
if use_prenet:
self.prenet = nn.Sequential(
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU()
self.prenet = ResidualConv1dLayerNormBlock(
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
)
self.encoder = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
self.postnet = nn.Sequential(
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
nn.BatchNorm1d(self.hidden_channels))
elif encoder_type.lower() == 'time_depth_separable':
self.encoder = RelativePositionTransformer(
hidden_channels, hidden_channels, hidden_channels, **encoder_params
)
elif encoder_type.lower() == "gated_conv":
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == "residual_conv_bn":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params)
self.postnet = nn.Sequential(
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels)
)
elif encoder_type.lower() == "time_depth_separable":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
)
self.encoder = TimeDepthSeparableConvBlock(
hidden_channels, hidden_channels, hidden_channels, **encoder_params
)
else:
raise ValueError(" [!] Unkown encoder type.")
@ -143,8 +132,8 @@ class Encoder(nn.Module):
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
# duration predictor
self.duration_predictor = DurationPredictor(
hidden_channels + c_in_channels, hidden_channels_dp, 3,
dropout_p_dp)
hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
)
def forward(self, x, x_lengths, g=None):
"""
@ -159,15 +148,14 @@ class Encoder(nn.Module):
# [B, D, T]
x = torch.transpose(x, 1, -1)
# compute input sequence mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
# prenet
if hasattr(self, 'prenet') and self.use_prenet:
if hasattr(self, "prenet") and self.use_prenet:
x = self.prenet(x, x_mask)
# encoder
x = self.encoder(x, x_mask)
# postnet
if hasattr(self, 'postnet'):
if hasattr(self, "postnet"):
x = self.postnet(x) * x_mask
# set duration predictor input
if g is not None:

View File

@ -7,8 +7,7 @@ from ..generic.normalization import LayerNorm
class ResidualConv1dLayerNormBlock(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
num_layers, dropout_p):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
@ -38,10 +37,10 @@ class ResidualConv1dLayerNormBlock(nn.Module):
for idx in range(num_layers):
self.conv_layers.append(
nn.Conv1d(in_channels if idx == 0 else hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2))
nn.Conv1d(
in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@ -72,6 +71,7 @@ class InvConvNear(nn.Module):
perform 1x1 convolution separately. Cast 1x1 conv operation
to 2d by reshaping the input for efficiency.
"""
def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs): # pylint: disable=unused-argument
super().__init__()
assert num_splits % 2 == 0
@ -80,8 +80,7 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None
w_init = torch.qr(
torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
@ -97,28 +96,25 @@ class InvConvNear(nn.Module):
assert c % self.num_splits == 0
if x_mask is None:
x_mask = 1
x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
else:
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits,
c // self.num_splits, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t)
if reverse:
if self.weight_inv is not None:
weight = self.weight_inv
else:
weight = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
logdet = None
else:
weight = self.weight
if self.no_jacobian:
logdet = 0
else:
logdet = torch.logdet(
self.weight) * (c / self.num_splits) * x_len # [b]
logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len # [b]
weight = weight.view(self.num_splits, self.num_splits, 1, 1)
z = F.conv2d(x, weight)
@ -128,40 +124,42 @@ class InvConvNear(nn.Module):
return z, logdet
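Editor's note: the QR init above makes the 1x1 weight orthogonal (determinant of magnitude 1) and the column flip forces det > 0; a toy check of that, plus the invertibility the reverse pass relies on (sizes assumed):

    import torch

    num_splits = 4
    w = torch.qr(torch.FloatTensor(num_splits, num_splits).normal_())[0]  # orthogonal init, as above
    if torch.det(w) < 0:
        w[:, 0] = -1 * w[:, 0]
    assert torch.det(w) > 0
    x = torch.randn(3, num_splits)  # toy vectors in the split-channel space
    z = x @ w.t()                   # what the 1x1 conv computes per position
    assert torch.allclose(z @ torch.inverse(w).t(), x, atol=1e-5)  # reverse pass recovers x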
def store_inverse(self):
weight_inv = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
self.weight_inv = nn.Parameter(weight_inv, requires_grad=False)
class CouplingBlock(nn.Module):
"""Glow Affine Coupling block as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
https://arxiv.org/pdf/1811.00002.pdf
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size.
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
num_layers (int): number of WaveNet layers.
c_in_channels (int): number of conditioning input channels.
dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size.
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
num_layers (int): number of WaveNet layers.
c_in_channels (int): number of conditioning input channels.
dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
Note:
It does not use conditional inputs differently from WaveGlow.
Note:
It does not use conditional inputs differently from WaveGlow.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
sigmoid_scale=False):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
sigmoid_scale=False,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@ -183,8 +181,7 @@ class CouplingBlock(nn.Module):
end.bias.data.zero_()
self.end = end
# coupling layers
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
num_layers, c_in_channels, dropout_p)
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
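Editor's note: a generic affine-coupling sketch matching the docstring above (z0 = x0, z1 = s * x1 + t). Here t and s are random stand-ins for the WaveNet output, whereas the real block derives them from x0 and may reshape s through the sigmoid_scale branch.

    import torch

    x = torch.randn(2, 8, 11)                                 # (B, in_channels, T), toy sizes
    x0, x1 = x[:, :4], x[:, 4:]
    t, s = torch.randn(2, 4, 11), torch.rand(2, 4, 11) + 0.5  # stand-ins for the predicted shift/scale
    z = torch.cat([x0, s * x1 + t], dim=1)                    # forward coupling
    x1_rec = (z[:, 4:] - t) / s                               # reverse coupling
    assert torch.allclose(x1_rec, x1, atol=1e-5)              # only x1 is transformed, and it is invertible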
def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs): # pylint: disable=unused-argument
"""
@ -195,15 +192,15 @@ class CouplingBlock(nn.Module):
"""
if x_mask is None:
x_mask = 1
x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :]
x = self.start(x_0) * x_mask
x = self.wn(x, x_mask, g)
out = self.end(x)
z_0 = x_0
t = out[:, :self.in_channels // 2, :]
s = out[:, self.in_channels // 2:, :]
t = out[:, : self.in_channels // 2, :]
s = out[:, self.in_channels // 2 :, :]
if self.sigmoid_scale:
s = torch.log(1e-6 + torch.sigmoid(s + 2))

View File

@ -6,6 +6,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
try:
# TODO: fix pypi cython installation problem.
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c
CYTHON = True
except ModuleNotFoundError:
CYTHON = False
@ -30,8 +31,7 @@ def generate_path(duration, mask):
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
]))[:, :-1]
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path
@ -43,7 +43,7 @@ def maximum_path(value, mask):
def maximum_path_cython(value, mask):
""" Cython optimised version.
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""

View File

@ -48,16 +48,19 @@ class RelativePositionMultiHeadAttention(nn.Module):
proximal_init (bool, optional): enable/disable proximal init as in the paper.
Init key and query layer weights the same. Defaults to False.
"""
def __init__(self,
channels,
out_channels,
num_heads,
rel_attn_window_size=None,
heads_share=True,
dropout_p=0.,
input_length=None,
proximal_bias=False,
proximal_init=False):
def __init__(
self,
channels,
out_channels,
num_heads,
rel_attn_window_size=None,
heads_share=True,
dropout_p=0.0,
input_length=None,
proximal_bias=False,
proximal_init=False,
):
super().__init__()
assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
@ -82,15 +85,15 @@ class RelativePositionMultiHeadAttention(nn.Module):
# relative positional encoding layers
if rel_attn_window_size is not None:
n_heads_rel = 1 if heads_share else num_heads
rel_stddev = self.k_channels**-0.5
rel_stddev = self.k_channels ** -0.5
emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
self.k_channels) * rel_stddev)
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
self.k_channels) * rel_stddev)
self.register_parameter('emb_rel_k', emb_rel_k)
self.register_parameter('emb_rel_v', emb_rel_v)
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.register_parameter("emb_rel_k", emb_rel_k)
self.register_parameter("emb_rel_v", emb_rel_v)
# init layers
nn.init.xavier_uniform_(self.conv_q.weight)
@ -112,38 +115,30 @@ class RelativePositionMultiHeadAttention(nn.Module):
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.num_heads, self.k_channels,
t_t).transpose(2, 3)
query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.num_heads, self.k_channels,
t_s).transpose(2, 3)
value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3)
# compute raw attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
self.k_channels)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
# relative positional encoding for scores
if self.rel_attn_window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
# get relative key embeddings
key_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(
rel_logits)
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
# proximal bias
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attn_proximity_bias(t_s).to(
device=scores.device, dtype=scores.dtype)
scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype)
# attention score masking
if mask is not None:
# add small value to prevent oor error.
scores = scores.masked_fill(mask == 0, -1e4)
if self.input_length is not None:
block_mask = torch.ones_like(scores).triu(
-1 * self.input_length).tril(self.input_length)
block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length)
scores = scores * block_mask + -1e4 * (1 - block_mask)
# attention score normalization
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
@ -153,14 +148,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
output = torch.matmul(p_attn, value)
# relative positional encoding for values
if self.rel_attn_window_size is not None:
relative_weights = self._absolute_position_to_relative_position(
p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(
b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
@staticmethod
@ -195,20 +186,16 @@ class RelativePositionMultiHeadAttention(nn.Module):
return logits
def _get_relative_embeddings(self, relative_embeddings, length):
"""Convert embedding vestors to a tensor of embeddings
"""
"""Convert embedding vestors to a tensor of embeddings"""
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.rel_attn_window_size + 1), 0)
slice_start_position = max((self.rel_attn_window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:
slice_end_position]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
@staticmethod
@ -226,8 +213,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0])
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1,
2 * length - 1])[:, :, :length, length - 1:]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
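Editor's note: a shape-only sketch of the pad/reshape trick above (toy sizes assumed). The flatten in the shown code only works if the last dimension has already been padded from 2*length-1 to 2*length, so that pad is included here; the result is a square (query, key) score grid built without gather ops.

    import torch
    import torch.nn.functional as F

    batch, heads, length = 1, 2, 4
    x = torch.randn(batch, heads, length, 2 * length - 1)  # relative logits
    x = F.pad(x, [0, 1])                                    # one extra column so the flatten lines up
    x_flat = x.view([batch, heads, length * 2 * length])
    x_flat = F.pad(x_flat, [0, length - 1])
    x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
    assert x_final.shape == (batch, heads, length, length)  # absolute (query, key) grid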
@staticmethod
@ -239,7 +225,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
@ -267,19 +253,15 @@ class RelativePositionMultiHeadAttention(nn.Module):
class FeedForwardNetwork(nn.Module):
"""Feed Forward Inner layers for Transformer.
Args:
in_channels (int): input tensor channels.
out_channels (int): output tensor channels.
hidden_channels (int): inner layers hidden channels.
kernel_size (int): conv1d filter kernel size.
dropout_p (float, optional): dropout rate. Defaults to 0.
Args:
in_channels (int): input tensor channels.
out_channels (int): output tensor channels.
hidden_channels (int): inner layers hidden channels.
kernel_size (int): conv1d filter kernel size.
dropout_p (float, optional): dropout rate. Defaults to 0.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dropout_p=0.):
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):
super().__init__()
self.in_channels = in_channels
@ -288,14 +270,8 @@ class FeedForwardNetwork(nn.Module):
self.kernel_size = kernel_size
self.dropout_p = dropout_p
self.conv_1 = nn.Conv1d(in_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels,
out_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, x_mask):
@ -308,34 +284,37 @@ class FeedForwardNetwork(nn.Module):
class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Potional Encoding.
https://arxiv.org/abs/1803.02155
https://arxiv.org/abs/1803.02155
Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0.
rel_attn_window_size (int, optional): relation attention window size.
If 4, for each time step next and previous 4 time steps are attended.
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0.
rel_attn_window_size (int, optional): relation attention window size.
If 4, for each time step next and previous 4 time steps are attended.
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
num_layers,
kernel_size=1,
dropout_p=0.,
rel_attn_window_size=None,
input_length=None):
def __init__(
self,
in_channels,
out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
num_layers,
kernel_size=1,
dropout_p=0.0,
rel_attn_window_size=None,
input_length=None,
):
super().__init__()
self.hidden_channels = hidden_channels
self.hidden_channels_ffn = hidden_channels_ffn
@ -359,7 +338,9 @@ class RelativePositionTransformer(nn.Module):
num_heads,
rel_attn_window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length))
input_length=input_length,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
if hidden_channels != out_channels and (idx + 1) == self.num_layers:
@ -368,15 +349,14 @@ class RelativePositionTransformer(nn.Module):
self.ffn_layers.append(
FeedForwardNetwork(
hidden_channels,
hidden_channels if
(idx + 1) != self.num_layers else out_channels,
hidden_channels if (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p))
dropout_p=dropout_p,
)
)
self.norm_layers_2.append(
LayerNorm(hidden_channels if (
idx + 1) != self.num_layers else out_channels))
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))
def forward(self, x, x_mask):
"""
@ -394,7 +374,7 @@ class RelativePositionTransformer(nn.Module):
y = self.ffn_layers[i](x, x_mask)
y = self.dropout(y)
if (i + 1) == self.num_layers and hasattr(self, 'proj'):
if (i + 1) == self.num_layers and hasattr(self, "proj"):
x = self.proj(x)
x = self.norm_layers_2[i](x + y)


@ -34,26 +34,23 @@ class L1LossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.l1_loss(x * mask,
target * mask,
reduction='none')
loss = functional.l1_loss(x * mask, target * mask, reduction="none")
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.l1_loss(x * mask, target * mask, reduction='sum')
loss = functional.l1_loss(x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
return loss
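A toy end-to-end run of the masked L1 reduction above; the sequence_mask below is a simplified stand-in for the repo helper of the same name (assumed equivalent):

import torch
from torch.nn import functional

def sequence_mask(sequence_length, max_len):
    # True for valid time steps, False for padding.
    steps = torch.arange(max_len, device=sequence_length.device)
    return steps[None, :] < sequence_length[:, None]

lengths = torch.tensor([3, 5])                       # valid frames per batch item
x = torch.randn(2, 5, 4)                             # predictions (B, T, D)
target = torch.randn(2, 5, 4)                        # ground truth (B, T, D)
mask = sequence_mask(lengths, target.size(1)).unsqueeze(2).float()
mask = mask.expand_as(x)
loss = functional.l1_loss(x * mask, target * mask, reduction="sum") / mask.sum()
print(loss)                                          # padded frames contribute nothing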
class MSELossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(MSELossMasked, self).__init__()
super().__init__()
self.seq_len_norm = seq_len_norm
def forward(self, x, target, length):
@ -76,27 +73,23 @@ class MSELossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.mse_loss(x * mask,
target * mask,
reduction='none')
loss = functional.mse_loss(x * mask, target * mask, reduction="none")
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.mse_loss(x * mask,
target * mask,
reduction='sum')
loss = functional.mse_loss(x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
return loss
class SSIMLoss(torch.nn.Module):
"""SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""
def __init__(self):
super().__init__()
self.loss_func = ssim
@ -115,9 +108,7 @@ class SSIMLoss(torch.nn.Module):
loss: An average loss value in range [0, 1] masked by the length.
"""
if length is not None:
m = sequence_mask(sequence_length=length,
max_len=y.size(1)).unsqueeze(2).float().to(
y_hat.device)
m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
y_hat, y = y_hat * m, y * m
return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
@ -139,7 +130,7 @@ class AttentionEntropyLoss(nn.Module):
class BCELossMasked(nn.Module):
def __init__(self, pos_weight):
super(BCELossMasked, self).__init__()
super().__init__()
self.pos_weight = pos_weight
def forward(self, x, target, length):
@ -163,25 +154,20 @@ class BCELossMasked(nn.Module):
# mask: (batch, max_len, 1)
target.requires_grad = False
if length is not None:
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
x = x * mask
target = target * mask
num_items = mask.sum()
else:
num_items = torch.numel(x)
loss = functional.binary_cross_entropy_with_logits(
x,
target,
pos_weight=self.pos_weight,
reduction='sum')
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
loss = loss / num_items
return loss
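Why pos_weight matters here: stop-token targets are almost all zeros with a single one near the end of each utterance, so the positive class is up-weighted (TacotronLoss further below passes stopnet_pos_weight=10). A small numeric illustration with made-up logits:

import torch
from torch.nn import functional

logits = torch.tensor([[-3.0, -2.0, -1.0, 4.0]])     # hypothetical stopnet outputs
target = torch.tensor([[0.0, 0.0, 0.0, 1.0]])
loss = functional.binary_cross_entropy_with_logits(
    logits, target, pos_weight=torch.tensor(10.0), reduction="sum"
) / target.numel()
print(loss)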
class DifferentailSpectralLoss(nn.Module):
"""Differential Spectral Loss
https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf"""
def __init__(self, loss_func):
super().__init__()
@ -200,12 +186,12 @@ class DifferentailSpectralLoss(nn.Module):
target_diff = target[:, 1:] - target[:, :-1]
if length is None:
return self.loss_func(x_diff, target_diff)
return self.loss_func(x_diff, target_diff, length-1)
return self.loss_func(x_diff, target_diff, length - 1)
class GuidedAttentionLoss(torch.nn.Module):
def __init__(self, sigma=0.4):
super(GuidedAttentionLoss, self).__init__()
super().__init__()
self.sigma = sigma
def _make_ga_masks(self, ilens, olens):
@ -214,8 +200,7 @@ class GuidedAttentionLoss(torch.nn.Module):
max_olen = max(olens)
ga_masks = torch.zeros((B, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(
ilen, olen, self.sigma)
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma)
return ga_masks
def forward(self, att_ws, ilens, olens):
@ -229,8 +214,7 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 /
(2 * (sigma**2)))
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))
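The reformatted expression above reads more easily with plain broadcasting: the penalty is near zero close to the diagonal t_out / olen == t_in / ilen and approaches one away from it. An equivalent standalone version:

import torch

def make_ga_mask(ilen, olen, sigma=0.4):
    grid_x = torch.arange(olen).float().unsqueeze(1)   # decoder steps
    grid_y = torch.arange(ilen).float().unsqueeze(0)   # encoder steps
    return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * sigma ** 2))

print(make_ga_mask(ilen=6, olen=4))                    # small values hug the diagonal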
@staticmethod
def _make_masks(ilens, olens):
@ -249,18 +233,19 @@ class Huber(nn.Module):
length: B
"""
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
return torch.nn.functional.smooth_l1_loss(
x * mask, y * mask, reduction='sum') / mask.sum()
return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()
########################
# MODEL LOSS LAYERS
########################
class TacotronLoss(torch.nn.Module):
"""Collection of Tacotron set-up based on provided config."""
def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
super(TacotronLoss, self).__init__()
super().__init__()
self.stopnet_pos_weight = stopnet_pos_weight
self.ga_alpha = c.ga_alpha
self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha
@ -273,12 +258,9 @@ class TacotronLoss(torch.nn.Module):
# postnet and decoder loss
if c.loss_masking:
self.criterion = L1LossMasked(c.seq_len_norm) if c.model in [
"Tacotron"
] else MSELossMasked(c.seq_len_norm)
self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm)
else:
self.criterion = nn.L1Loss() if c.model in ["Tacotron"
] else nn.MSELoss()
self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss()
# guided attention loss
if c.ga_alpha > 0:
self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)
@ -290,13 +272,23 @@ class TacotronLoss(torch.nn.Module):
self.criterion_ssim = SSIMLoss()
# stopnet loss
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(
pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
def forward(self, postnet_output, decoder_output, mel_input, linear_input,
stopnet_output, stopnet_target, output_lens, decoder_b_output,
alignments, alignment_lens, alignments_backwards, input_lens):
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None
def forward(
self,
postnet_output,
decoder_output,
mel_input,
linear_input,
stopnet_output,
stopnet_target,
output_lens,
decoder_b_output,
alignments,
alignment_lens,
alignments_backwards,
input_lens,
):
# decoder outputs linear or mel spectrograms for Tacotron and Tacotron2
# the target should be set accordingly
@ -309,85 +301,80 @@ class TacotronLoss(torch.nn.Module):
# decoder and postnet losses
if self.config.loss_masking:
if self.decoder_alpha > 0:
decoder_loss = self.criterion(decoder_output, mel_input,
output_lens)
decoder_loss = self.criterion(decoder_output, mel_input, output_lens)
if self.postnet_alpha > 0:
postnet_loss = self.criterion(postnet_output, postnet_target,
output_lens)
postnet_loss = self.criterion(postnet_output, postnet_target, output_lens)
else:
if self.decoder_alpha > 0:
decoder_loss = self.criterion(decoder_output, mel_input)
if self.postnet_alpha > 0:
postnet_loss = self.criterion(postnet_output, postnet_target)
loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss
return_dict['decoder_loss'] = decoder_loss
return_dict['postnet_loss'] = postnet_loss
return_dict["decoder_loss"] = decoder_loss
return_dict["postnet_loss"] = postnet_loss
# stopnet loss
stop_loss = self.criterion_st(
stopnet_output, stopnet_target,
output_lens) if self.config.stopnet else torch.zeros(1)
stop_loss = (
self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1)
)
if not self.config.separate_stopnet and self.config.stopnet:
loss += stop_loss
return_dict['stopnet_loss'] = stop_loss
return_dict["stopnet_loss"] = stop_loss
# backward decoder loss (if enabled)
if self.config.bidirectional_decoder:
if self.config.loss_masking:
decoder_b_loss = self.criterion(
torch.flip(decoder_b_output, dims=(1, )), mel_input,
output_lens)
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens)
else:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output)
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output)
loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss)
return_dict['decoder_b_loss'] = decoder_b_loss
return_dict['decoder_c_loss'] = decoder_c_loss
return_dict["decoder_b_loss"] = decoder_b_loss
return_dict["decoder_c_loss"] = decoder_c_loss
# double decoder consistency loss (if enabled)
if self.config.double_decoder_consistency:
if self.config.loss_masking:
decoder_b_loss = self.criterion(decoder_b_output, mel_input,
output_lens)
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
else:
decoder_b_loss = self.criterion(decoder_b_output, mel_input)
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
return_dict['decoder_coarse_loss'] = decoder_b_loss
return_dict['decoder_ddc_loss'] = attention_c_loss
return_dict["decoder_coarse_loss"] = decoder_b_loss
return_dict["decoder_ddc_loss"] = attention_c_loss
# guided attention loss (if enabled)
if self.config.ga_alpha > 0:
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
loss += ga_loss * self.ga_alpha
return_dict['ga_loss'] = ga_loss
return_dict["ga_loss"] = ga_loss
# decoder differential spectral loss
if self.config.decoder_diff_spec_alpha > 0:
decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens)
loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha
return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss
return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss
# postnet differential spectral loss
if self.config.postnet_diff_spec_alpha > 0:
postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens)
loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha
return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss
return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss
# decoder ssim loss
if self.config.decoder_ssim_alpha > 0:
decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens)
loss += decoder_ssim_loss * self.postnet_ssim_alpha
return_dict['decoder_ssim_loss'] = decoder_ssim_loss
return_dict["decoder_ssim_loss"] = decoder_ssim_loss
# postnet ssim loss
if self.config.postnet_ssim_alpha > 0:
postnet_ssim_loss = self.criterion_ssim(postnet_output, postnet_target, output_lens)
loss += postnet_ssim_loss * self.postnet_ssim_alpha
return_dict['postnet_ssim_loss'] = postnet_ssim_loss
return_dict["postnet_ssim_loss"] = postnet_ssim_loss
return_dict['loss'] = loss
return_dict["loss"] = loss
# check if any loss is NaN
for key, loss in return_dict.items():
@ -401,22 +388,18 @@ class GlowTTSLoss(torch.nn.Module):
super().__init__()
self.constant_factor = 0.5 * math.log(2 * math.pi)
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,
o_attn_dur, x_lengths):
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
return_dict = {}
# flow loss - neg log likelihood
pz = torch.sum(scales) + 0.5 * torch.sum(
torch.exp(-2 * scales) * (z - means)**2)
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
torch.sum(y_lengths) * z.shape[1])
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])
# duration loss - MSE
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
# duration loss - huber loss
loss_dur = torch.nn.functional.smooth_l1_loss(
o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths)
return_dict['loss'] = log_mle + loss_dur
return_dict['log_mle'] = log_mle
return_dict['loss_dur'] = loss_dur
loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
return_dict["loss"] = log_mle + loss_dur
return_dict["log_mle"] = log_mle
return_dict["loss_dur"] = loss_dur
# check if any loss is NaN
for key, loss in return_dict.items():
@ -441,7 +424,7 @@ class SpeedySpeechLoss(nn.Module):
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss}
def mse_loss_custom(x, y):
@ -452,26 +435,27 @@ def mse_loss_custom(x, y):
class MDNLoss(nn.Module):
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
"""
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf."""
def forward(self, logp, text_lengths, mel_lengths): # pylint: disable=no-self-use
'''
"""
Shapes:
mu: [B, D, T]
log_sigma: [B, D, T]
mel_spec: [B, D, T]
'''
"""
B, T_seq, T_mel = logp.shape
log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4)
log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T_mel):
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t],
(0, 0, 1, -1), value=-1e4)], dim=-1)
prev_step = torch.cat(
[log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)],
dim=-1,
)
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
mdn_loss = -alpha_last.mean() / T_seq
return mdn_loss#, log_prob_matrix
return mdn_loss # , log_prob_matrix
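A toy-sized run of the recursion above (B=1, three text tokens, four mel frames): at each mel frame a text token either keeps absorbing frames or takes over from its predecessor, and the two options are combined in log-space. logp is assumed to be precomputed by the model.

import torch
from torch.nn import functional

logp = torch.randn(1, 3, 4)
log_alpha = logp.new_ones(1, 3, 4) * (-1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, 4):
    stay = log_alpha[:, :, t - 1 : t]
    advance = functional.pad(stay, (0, 0, 1, -1), value=-1e4)   # shift by one text token
    log_alpha[:, :, t] = torch.logsumexp(torch.cat([stay, advance], dim=-1), dim=-1) + logp[:, :, t]
print(log_alpha[0, 2, 3])   # log-likelihood of alignments ending on the last token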
class AlignTTSLoss(nn.Module):
@ -487,6 +471,7 @@ class AlignTTSLoss(nn.Module):
Args:
c (dict): TTS model configuration.
"""
def __init__(self, c):
super().__init__()
self.mdn_loss = MDNLoss()
@ -499,10 +484,10 @@ class AlignTTSLoss(nn.Module):
self.spec_loss_alpha = c.spec_loss_alpha
self.mdn_alpha = c.mdn_alpha
def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
input_lens, step, phase):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
step)
def forward(
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase
):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)
@ -521,11 +506,11 @@ class AlignTTSLoss(nn.Module):
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}
@staticmethod
def _set_alpha(step, alpha_settings):
'''Set the loss alpha wrt number of steps.
"""Set the loss alpha wrt number of steps.
Return the corresponding value if no schedule is set.
Example:
@ -536,7 +521,7 @@ class AlignTTSLoss(nn.Module):
Args:
step (int): number of training steps.
alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above.
'''
"""
return_alpha = None
if isinstance(alpha_settings, list):
for key, alpha in alpha_settings:
@ -547,8 +532,7 @@ class AlignTTSLoss(nn.Module):
return return_alpha
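The body of the schedule loop is cut off by the hunk above; the sketch below assumes the usual behaviour that, for a list of [start_step, value] pairs, the last pair whose start_step has already been passed wins, while plain numbers pass through unchanged. Treat it as an illustration of the documented schedule, not a verbatim copy.

def set_loss_alpha(step, alpha_settings):
    return_alpha = None
    if isinstance(alpha_settings, list):
        for key, alpha in alpha_settings:
            if key < step:                                 # assumed comparison
                return_alpha = alpha
    elif isinstance(alpha_settings, (float, int)):
        return_alpha = alpha_settings
    return return_alpha

print(set_loss_alpha(500, [[0, 1.0], [1000, 0.5]]))    # 1.0
print(set_loss_alpha(2000, [[0, 1.0], [1000, 0.5]]))   # 0.5
print(set_loss_alpha(2000, 0.25))                      # 0.25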
def set_alphas(self, step):
'''Set the alpha values for all the loss functions
'''
"""Set the alpha values for all the loss functions"""
ssim_alpha = self._set_alpha(step, self.ssim_alpha)
dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)


@ -14,20 +14,18 @@ class LocationLayer(nn.Module):
attention_n_filters (int, optional): number of filters in convolution. Defaults to 32.
attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31.
"""
def __init__(self,
attention_dim,
attention_n_filters=32,
attention_kernel_size=31):
super(LocationLayer, self).__init__()
def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31):
super().__init__()
self.location_conv1d = nn.Conv1d(
in_channels=2,
out_channels=attention_n_filters,
kernel_size=attention_kernel_size,
stride=1,
padding=(attention_kernel_size - 1) // 2,
bias=False)
self.location_dense = Linear(
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
bias=False,
)
self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh")
def forward(self, attention_cat):
"""
@ -35,8 +33,7 @@ class LocationLayer(nn.Module):
attention_cat: [B, 2, C]
"""
processed_attention = self.location_conv1d(attention_cat)
processed_attention = self.location_dense(
processed_attention.transpose(1, 2))
processed_attention = self.location_dense(processed_attention.transpose(1, 2))
return processed_attention
@ -49,31 +46,31 @@ class GravesAttention(nn.Module):
query_dim (int): number of channels in query tensor.
K (int): number of Gaussian heads to be used for computing attention.
"""
COEF = 0.3989422917366028 # numpy.sqrt(1/(2*numpy.pi))
def __init__(self, query_dim, K):
super(GravesAttention, self).__init__()
super().__init__()
self._mask_value = 1e-8
self.K = K
# self.attention_alignment = 0.05
self.eps = 1e-5
self.J = None
self.N_a = nn.Sequential(
nn.Linear(query_dim, query_dim, bias=True),
nn.ReLU(),
nn.Linear(query_dim, 3*K, bias=True))
nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True)
)
self.attention_weights = None
self.mu_prev = None
self.init_layers()
def init_layers(self):
torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.) # bias mean
torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10) # bias std
torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0) # bias mean
torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10) # bias std
def init_states(self, inputs):
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]:
self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)
@ -108,7 +105,7 @@ class GravesAttention(nn.Module):
mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
g_t = torch.softmax(g_t, dim=-1) + self.eps
j = self.J[:inputs.size(1)+1]
j = self.J[: inputs.size(1) + 1]
# attention weights
phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
@ -164,21 +161,29 @@ class OriginalAttention(nn.Module):
trans_agent (bool): enable/disable transition agent in the forward attention.
forward_attn_mask (int): enable/disable explicit masking in forward attention. It is especially useful to set at inference time.
"""
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent, forward_attn_mask):
super(OriginalAttention, self).__init__()
self.query_layer = Linear(
query_dim, attention_dim, bias=False, init_gain='tanh')
self.inputs_layer = Linear(
embedding_dim, attention_dim, bias=False, init_gain='tanh')
# pylint: disable=attribute-defined-outside-init
def __init__(
self,
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
):
super().__init__()
self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh")
self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh")
self.v = Linear(attention_dim, 1, bias=True)
if trans_agent:
self.ta = nn.Linear(
query_dim + embedding_dim, 1, bias=True)
self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True)
if location_attention:
self.location_layer = LocationLayer(
attention_dim,
@ -202,9 +207,7 @@ class OriginalAttention(nn.Module):
def init_forward_attn(self, inputs):
B = inputs.shape[0]
T = inputs.shape[1]
self.alpha = torch.cat(
[torch.ones([B, 1]),
torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)
def init_location_attention(self, inputs):
@ -230,14 +233,10 @@ class OriginalAttention(nn.Module):
self.attention_weights_cum += alignments
def get_location_attention(self, query, processed_inputs):
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
dim=1)
attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights +
processed_inputs))
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs))
energies = energies.squeeze(-1)
return energies, processed_query
@ -264,24 +263,17 @@ class OriginalAttention(nn.Module):
def apply_forward_attention(self, alignment):
# forward attention
fwd_shifted_alpha = F.pad(
self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0))
fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0))
# compute transition potentials
alpha = ((1 - self.u) * self.alpha
+ self.u * fwd_shifted_alpha
+ 1e-8) * alignment
alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment
# force incremental alignment
if not self.training and self.forward_attn_mask:
_, n = fwd_shifted_alpha.max(1)
val, _ = alpha.max(1)
for b in range(alignment.shape[0]):
alpha[b, n[b] + 3:] = 0
alpha[b, :(
n[b] - 1
)] = 0 # ignore all previous states to prevent repetition.
alpha[b,
(n[b] - 2
)] = 0.01 * val[b] # smoothing factor for the prev step
alpha[b, n[b] + 3 :] = 0
alpha[b, : (n[b] - 1)] = 0 # ignore all previous states to prevent repetition.
alpha[b, (n[b] - 2)] = 0.01 * val[b] # smoothing factor for the prev step
# renormalize attention weights
alpha = alpha / alpha.sum(dim=1, keepdim=True)
return alpha
@ -295,11 +287,9 @@ class OriginalAttention(nn.Module):
mask: [B, T_en]
"""
if self.location_attention:
attention, _ = self.get_location_attention(
query, processed_inputs)
attention, _ = self.get_location_attention(query, processed_inputs)
else:
attention, _ = self.get_attention(
query, processed_inputs)
attention, _ = self.get_attention(query, processed_inputs)
# apply masking
if mask is not None:
attention.data.masked_fill_(~mask, self._mask_value)
@ -311,9 +301,7 @@ class OriginalAttention(nn.Module):
if self.norm == "softmax":
alignment = torch.softmax(attention, dim=-1)
elif self.norm == "sigmoid":
alignment = torch.sigmoid(attention) / torch.sigmoid(
attention).sum(
dim=1, keepdim=True)
alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True)
else:
raise ValueError("Unknown value for attention norm type")
@ -367,19 +355,20 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
alpha (float, optional): alpha parameter of the beta-binomial prior filter. Defaults to 0.1 from the paper.
beta (float, optional): beta parameter of the beta-binomial prior filter. Defaults to 0.9 from the paper.
"""
def __init__(
self,
query_dim,
embedding_dim, # pylint: disable=unused-argument
attention_dim,
static_filter_dim,
static_kernel_size,
dynamic_filter_dim,
dynamic_kernel_size,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
):
self,
query_dim,
embedding_dim, # pylint: disable=unused-argument
attention_dim,
static_filter_dim,
static_kernel_size,
dynamic_filter_dim,
dynamic_kernel_size,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
):
super().__init__()
self._mask_value = 1e-8
self.dynamic_filter_dim = dynamic_filter_dim
@ -388,9 +377,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
self.attention_weights = None
# setup key and query layers
self.query_layer = nn.Linear(query_dim, attention_dim)
self.key_layer = nn.Linear(
attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
)
self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False)
self.static_filter_conv = nn.Conv1d(
1,
static_filter_dim,
@ -402,8 +389,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
self.v = nn.Linear(attention_dim, 1, bias=False)
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
alpha, beta)
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
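With the defaults above (prior_filter_len=11, alpha=0.1, beta=0.9) the beta-binomial prior has mean n * alpha / (alpha + beta) = 1, i.e. it favours advancing by roughly one encoder step per decoder step, which is what keeps the attention close to monotonic. It can be inspected directly:

import torch
from scipy.stats import betabinom

prior_filter_len, alpha, beta = 11, 0.1, 0.9
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
print(torch.FloatTensor(prior).flip(0))     # same values as the buffer registered above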
# pylint: disable=unused-argument
@ -416,8 +402,8 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
"""
# compute prior filters
prior_filter = F.conv1d(
F.pad(self.attention_weights.unsqueeze(1),
(self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1)
)
prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
G = self.key_layer(torch.tanh(self.query_layer(query)))
# compute dynamic filters
@ -430,10 +416,12 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
# compute static filters
static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
alignment = self.v(
torch.tanh(
self.static_filter_layer(static_filter) +
self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
alignment = (
self.v(
torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter))
).squeeze(-1)
+ prior_filter
)
# compute attention weights
attention_weights = F.softmax(alignment, dim=-1)
# apply masking
@ -451,33 +439,52 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
B = inputs.size(0)
T = inputs.size(1)
self.attention_weights = torch.zeros([B, T], device=inputs.device)
self.attention_weights[:, 0] = 1.
self.attention_weights[:, 0] = 1.0
def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent, forward_attn_mask, attn_K):
def init_attn(
attn_type,
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
attn_K,
):
if attn_type == "original":
return OriginalAttention(query_dim, embedding_dim, attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size, windowing,
norm, forward_attn, trans_agent,
forward_attn_mask)
return OriginalAttention(
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
)
if attn_type == "graves":
return GravesAttention(query_dim, attn_K)
if attn_type == "dynamic_convolution":
return MonotonicDynamicConvolutionAttention(query_dim,
embedding_dim,
attention_dim,
static_filter_dim=8,
static_kernel_size=21,
dynamic_filter_dim=8,
dynamic_kernel_size=21,
prior_filter_len=11,
alpha=0.1,
beta=0.9)
return MonotonicDynamicConvolutionAttention(
query_dim,
embedding_dim,
attention_dim,
static_filter_dim=8,
static_kernel_size=21,
dynamic_filter_dim=8,
dynamic_kernel_size=21,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
)
raise RuntimeError(
" [!] Given Attention Type '{attn_type}' is not exist.")
raise RuntimeError(" [!] Given Attention Type '{attn_type}' is not exist.")


@ -12,20 +12,14 @@ class Linear(nn.Module):
bias (bool, optional): enable/disable bias in the layer. Defaults to True.
init_gain (str, optional): method to compute the gain in the weight initialization based on the nonlinear activation used afterwards. Defaults to 'linear'.
"""
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(Linear, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)
def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
super().__init__()
self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
self._init_w(init_gain)
def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))
def forward(self, x):
return self.linear_layer(x)
@ -42,21 +36,15 @@ class LinearBN(nn.Module):
bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'.
"""
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(LinearBN, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)
def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
super().__init__()
self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
self._init_w(init_gain)
def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))
def forward(self, x):
"""
@ -96,27 +84,21 @@ class Prenet(nn.Module):
Defaults to [256, 256].
bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.
"""
# pylint: disable=dangerous-default-value
def __init__(self,
in_features,
prenet_type="original",
prenet_dropout=True,
out_features=[256, 256],
bias=True):
super(Prenet, self).__init__()
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
super().__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
in_features = [in_features] + out_features[:-1]
if prenet_type == "bn":
self.linear_layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
self.linear_layers = nn.ModuleList(
[LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
)
elif prenet_type == "original":
self.linear_layers = nn.ModuleList([
Linear(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
self.linear_layers = nn.ModuleList(
[Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
)
def forward(self, x):
for linear in self.linear_layers:


@ -11,8 +11,7 @@ class GST(nn.Module):
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim=None):
super().__init__()
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
gst_embedding_dim, speaker_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim)
def forward(self, inputs, speaker_embedding=None):
enc_out = self.encoder(inputs)
@ -39,24 +38,17 @@ class ReferenceEncoder(nn.Module):
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)) for i in range(num_layers)
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList([
nn.BatchNorm2d(num_features=filter_size)
for filter_size in filters[1:]
])
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])
post_conv_height = self.calculate_post_conv_height(
num_mel, 3, 2, 1, num_layers)
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
self.recurrence = nn.GRU(
input_size=filters[-1] * post_conv_height,
hidden_size=embedding_dim // 2,
batch_first=True)
input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True
)
def forward(self, inputs):
batch_size = inputs.size(0)
@ -81,8 +73,7 @@ class ReferenceEncoder(nn.Module):
return out.squeeze(0)
@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad,
n_convs):
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1
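For a concrete feel of the arithmetic above: assuming 80 mel bins and the reference encoder's default six (3, 3)/stride-2/pad-1 convolutions (both numbers are assumptions about the config, not taken from this hunk), the frequency axis shrinks 80 -> 40 -> 20 -> 10 -> 5 -> 3 -> 2, which sizes the GRU input as filters[-1] * post_conv_height.

def post_conv_height(height, kernel_size, stride, pad, n_convs):
    # Same arithmetic as calculate_post_conv_height above.
    for _ in range(n_convs):
        height = (height - kernel_size + 2 * pad) // stride + 1
    return height

print(post_conv_height(80, 3, 2, 1, 6))     # 2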
@ -92,8 +83,7 @@ class ReferenceEncoder(nn.Module):
class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings."""
def __init__(self, num_heads, num_style_tokens,
embedding_dim, speaker_embedding_dim=None):
def __init__(self, num_heads, num_style_tokens, embedding_dim, speaker_embedding_dim=None):
super().__init__()
self.query_dim = embedding_dim // 2
@ -102,35 +92,31 @@ class StyleTokenLayer(nn.Module):
self.query_dim += speaker_embedding_dim
self.key_dim = embedding_dim // num_heads
self.style_tokens = nn.Parameter(
torch.FloatTensor(num_style_tokens, self.key_dim))
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
self.attention = MultiHeadAttention(
query_dim=self.query_dim,
key_dim=self.key_dim,
num_units=embedding_dim,
num_heads=num_heads)
query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
)
def forward(self, inputs):
batch_size = inputs.size(0)
prosody_encoding = inputs.unsqueeze(1)
# prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128]
tokens = torch.tanh(self.style_tokens) \
.unsqueeze(0) \
.expand(batch_size, -1, -1)
tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1)
# tokens: 3D tensor [batch_size, num tokens, token embedding size]
style_embed = self.attention(prosody_encoding, tokens)
return style_embed
class MultiHeadAttention(nn.Module):
'''
"""
input:
query --- [N, T_q, query_dim]
key --- [N, T_k, key_dim]
output:
out --- [N, T_q, num_units]
'''
"""
def __init__(self, query_dim, key_dim, num_units, num_heads):
@ -139,12 +125,9 @@ class MultiHeadAttention(nn.Module):
self.num_heads = num_heads
self.key_dim = key_dim
self.W_query = nn.Linear(
in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
def forward(self, query, key):
queries = self.W_query(query) # [N, T_q, num_units]
@ -152,25 +135,17 @@ class MultiHeadAttention(nn.Module):
values = self.W_value(key)
split_size = self.num_units // self.num_heads
queries = torch.stack(
torch.split(queries, split_size, dim=2),
dim=0) # [h, N, T_q, num_units/h]
keys = torch.stack(
torch.split(keys, split_size, dim=2),
dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(
torch.split(values, split_size, dim=2),
dim=0) # [h, N, T_k, num_units/h]
queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3)) # [h, N, T_q, T_k]
scores = scores / (self.key_dim**0.5)
scores = scores / (self.key_dim ** 0.5)
scores = F.softmax(scores, dim=3)
# out = score * V
out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
out = torch.cat(
torch.split(out, 1, dim=0),
dim=3).squeeze(0) # [N, T_q, num_units]
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
return out
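A shape walk-through of the split/stack pattern above with made-up sizes; the scaling constant here uses the per-head chunk size, standing in for self.key_dim:

import torch
from torch.nn import functional as F

num_heads, N, T_q, T_k, num_units = 4, 2, 5, 7, 16
queries = torch.randn(N, T_q, num_units)
keys = torch.randn(N, T_k, num_units)
values = torch.randn(N, T_k, num_units)
split_size = num_units // num_heads
q = torch.stack(torch.split(queries, split_size, dim=2), dim=0)   # [h, N, T_q, num_units/h]
k = torch.stack(torch.split(keys, split_size, dim=2), dim=0)      # [h, N, T_k, num_units/h]
v = torch.stack(torch.split(values, split_size, dim=2), dim=0)    # [h, N, T_k, num_units/h]
scores = F.softmax(torch.matmul(q, k.transpose(2, 3)) / split_size ** 0.5, dim=3)
out = torch.cat(torch.split(torch.matmul(scores, v), 1, dim=0), dim=3).squeeze(0)
print(out.shape)                                                   # torch.Size([2, 5, 16])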


@ -23,24 +23,14 @@ class BatchNormConv1d(nn.Module):
- output: (B, D)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation=None):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None):
super(BatchNormConv1d, self).__init__()
super().__init__()
self.padding = padding
self.padder = nn.ConstantPad1d(padding, 0)
self.conv1d = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=False)
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False
)
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation
@ -48,15 +38,14 @@ class BatchNormConv1d(nn.Module):
def init_layers(self):
if isinstance(self.activation, torch.nn.ReLU):
w_gain = 'relu'
w_gain = "relu"
elif isinstance(self.activation, torch.nn.Tanh):
w_gain = 'tanh'
w_gain = "tanh"
elif self.activation is None:
w_gain = 'linear'
w_gain = "linear"
else:
raise RuntimeError('Unknown activation function')
torch.nn.init.xavier_uniform_(
self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
raise RuntimeError("Unknown activation function")
torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
def forward(self, x):
x = self.padder(x)
@ -81,7 +70,7 @@ class Highway(nn.Module):
# TODO: Try GLU layer
def __init__(self, in_features, out_feature):
super(Highway, self).__init__()
super().__init__()
self.H = nn.Linear(in_features, out_feature)
self.H.bias.data.zero_()
self.T = nn.Linear(in_features, out_feature)
@ -91,10 +80,8 @@ class Highway(nn.Module):
# self.init_layers()
def init_layers(self):
torch.nn.init.xavier_uniform_(
self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
torch.nn.init.xavier_uniform_(
self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu"))
torch.nn.init.xavier_uniform_(self.T.weight, gain=torch.nn.init.calculate_gain("sigmoid"))
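The forward pass of Highway is cut off by the hunk boundary below; what it computes is the standard highway combination y = H(x) * T(x) + x * (1 - T(x)). A sketch under that assumption (the negative gate bias is also an assumption, chosen so the layer starts near the identity):

import torch
import torch.nn as nn

in_features = 128
x = torch.randn(3, in_features)
H = nn.Linear(in_features, in_features)
T = nn.Linear(in_features, in_features)
T.bias.data.fill_(-1.0)                     # assumed initial gate bias
gate = torch.sigmoid(T(x))
y = torch.relu(H(x)) * gate + x * (1.0 - gate)
print(y.shape)                              # torch.Size([3, 128])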
def forward(self, inputs):
H = self.relu(self.H(inputs))
@ -104,30 +91,33 @@ class Highway(nn.Module):
class CBHG(nn.Module):
"""CBHG module: a recurrent neural network composed of:
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units

Args:
    in_features (int): sample size
    K (int): max filter size in conv bank
    projections (list): conv channel sizes for conv projections
    num_highways (int): number of highway layers

Shapes:
    - input: (B, C, T_in)
    - output: (B, T_in, C*2)
"""
#pylint: disable=dangerous-default-value
def __init__(self,
in_features,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4):
super(CBHG, self).__init__()
# pylint: disable=dangerous-default-value
def __init__(
self,
in_features,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4,
):
super().__init__()
self.in_features = in_features
self.conv_bank_features = conv_bank_features
self.highway_features = highway_features
@ -136,14 +126,19 @@ class CBHG(nn.Module):
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList([
BatchNormConv1d(in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu) for k in range(1, K + 1)
])
self.conv1d_banks = nn.ModuleList(
[
BatchNormConv1d(
in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu,
)
for k in range(1, K + 1)
]
)
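The asymmetric padding [(k - 1) // 2, k // 2] used in the bank above keeps the time axis intact for both odd and even kernel sizes, so all K branch outputs can be concatenated along the channel dimension. A quick check:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 10)                   # (B, C, T)
for k in (2, 3, 4):
    padded = nn.ConstantPad1d(((k - 1) // 2, k // 2), 0)(x)
    y = nn.Conv1d(8, 4, kernel_size=k, stride=1, padding=0, bias=False)(padded)
    print(k, y.shape)                       # time length stays 10 for every k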
# max pooling of conv bank, with padding
# TODO: try average pooling OR larger kernel size
out_features = [K * conv_bank_features] + conv_projections[:-1]
@ -151,31 +146,16 @@ class CBHG(nn.Module):
activations += [None]
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, conv_projections,
activations):
layer = BatchNormConv1d(in_size,
out_size,
kernel_size=3,
stride=1,
padding=[1, 1],
activation=ac)
for (in_size, out_size, ac) in zip(out_features, conv_projections, activations):
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
if self.highway_features != conv_projections[-1]:
self.pre_highway = nn.Linear(conv_projections[-1],
highway_features,
bias=False)
self.highways = nn.ModuleList([
Highway(highway_features, highway_features)
for _ in range(num_highways)
])
self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False)
self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)])
# bi-directional GRU layer
self.gru = nn.GRU(gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True)
def forward(self, inputs):
# (B, in_features, T_in)
@ -210,7 +190,7 @@ class EncoderCBHG(nn.Module):
r"""CBHG module with Encoder specific arguments"""
def __init__(self):
super(EncoderCBHG, self).__init__()
super().__init__()
self.cbhg = CBHG(
128,
K=16,
@ -218,7 +198,8 @@ class EncoderCBHG(nn.Module):
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4)
num_highways=4,
)
def forward(self, x):
return self.cbhg(x)
@ -235,7 +216,7 @@ class Encoder(nn.Module):
"""
def __init__(self, in_features):
super(Encoder, self).__init__()
super().__init__()
self.prenet = Prenet(in_features, out_features=[256, 128])
self.cbhg = EncoderCBHG()
@ -248,7 +229,7 @@ class Encoder(nn.Module):
class PostCBHG(nn.Module):
def __init__(self, mel_dim):
super(PostCBHG, self).__init__()
super().__init__()
self.cbhg = CBHG(
mel_dim,
K=8,
@ -256,7 +237,8 @@ class PostCBHG(nn.Module):
conv_projections=[256, mel_dim],
highway_features=128,
gru_features=128,
num_highways=4)
num_highways=4,
)
def forward(self, x):
return self.cbhg(x)
@ -289,11 +271,25 @@ class Decoder(nn.Module):
# Pylint gets confused by PyTorch conventions here
# pylint: disable=attribute-defined-outside-init
def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet):
super(Decoder, self).__init__()
def __init__(
self,
in_channels,
frame_channels,
r,
memory_size,
attn_type,
attn_windowing,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
):
super().__init__()
self.r_init = r
self.r = r
self.in_channels = in_channels
@ -305,33 +301,30 @@ class Decoder(nn.Module):
self.query_dim = 256
# memory -> |Prenet| -> processed_memory
prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels
self.prenet = Prenet(
prenet_dim,
prenet_type,
prenet_dropout,
out_features=[256, 128])
self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
# attention_rnn generates queries for the attention mechanism
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K,
)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
# learn init values instead of zero init.
@ -364,8 +357,7 @@ class Decoder(nn.Module):
# decoder states
self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
self.decoder_rnn_hiddens = [
torch.zeros(1, device=inputs.device).repeat(B, 256)
for idx in range(len(self.decoder_rnns))
torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns))
]
self.context_vec = inputs.data.new(B, self.in_channels).zero_()
# cache attention inputs
@ -376,8 +368,7 @@ class Decoder(nn.Module):
attentions = torch.stack(attentions).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(
outputs.size(0), -1, self.frame_channels)
outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
outputs = outputs.transpose(1, 2)
return outputs, attentions, stop_tokens
@ -386,18 +377,15 @@ class Decoder(nn.Module):
processed_memory = self.prenet(self.memory_input)
# Attention RNN
self.attention_rnn_hidden = self.attention_rnn(
torch.cat((processed_memory, self.context_vec), -1),
self.attention_rnn_hidden)
self.context_vec = self.attention(
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden
)
self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
decoder_input, self.decoder_rnn_hiddens[idx])
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
# Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input
@ -418,17 +406,17 @@ class Decoder(nn.Module):
if self.use_memory_queue:
if self.memory_size > self.r:
# memory queue size is larger than number of frames per decoder iter
self.memory_input = torch.cat([
new_memory, self.memory_input[:, :(
self.memory_size - self.r) * self.frame_channels].clone()
], dim=-1)
self.memory_input = torch.cat(
[new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()],
dim=-1,
)
else:
# memory queue size smaller than number of frames per decoder iter
self.memory_input = new_memory[:, :self.memory_size * self.frame_channels]
self.memory_input = new_memory[:, : self.memory_size * self.frame_channels]
else:
# use only the last frame prediction
# assert new_memory.shape[-1] == self.r * self.frame_channels
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :]
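A toy run of the memory-queue branch above (memory_size > r) with frame_channels=1, r=2, memory_size=4: the newest r frames go to the front and the oldest r frames fall off the end.

import torch

frame_channels, r, memory_size = 1, 2, 4
memory_input = torch.zeros(1, memory_size * frame_channels)
for step in (1, 2, 3):
    new_memory = torch.full((1, r * frame_channels), float(step))
    memory_input = torch.cat(
        [new_memory, memory_input[:, : (memory_size - r) * frame_channels].clone()], dim=-1
    )
    print(memory_input)
# [[1., 1., 0., 0.]] -> [[2., 2., 1., 1.]] -> [[3., 3., 2., 2.]]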
def forward(self, inputs, memory, mask):
"""
@ -487,8 +475,7 @@ class Decoder(nn.Module):
attentions += [attention]
stop_tokens += [stop_token]
t += 1
if t > inputs.shape[1] / 4 and (stop_token > 0.6
or attention[:, -1].item() > 0.6):
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
break
if t > self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
@ -503,11 +490,10 @@ class StopNet(nn.Module):
"""
def __init__(self, in_features):
super(StopNet, self).__init__()
super().__init__()
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(in_features, 1)
torch.nn.init.xavier_uniform_(
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear"))
def forward(self, inputs):
outputs = self.dropout(inputs)


@ -5,8 +5,8 @@ from .common_layers import Prenet, Linear
from .attentions import init_attn
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(nn.Module):
r"""Convolutions with Batch Normalization and non-linear activation.
@ -20,19 +20,17 @@ class ConvBNBlock(nn.Module):
- input: (B, C_in, T)
- output: (B, C_out, T)
"""
def __init__(self, in_channels, out_channels, kernel_size, activation=None):
super(ConvBNBlock, self).__init__()
super().__init__()
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.convolution1d = nn.Conv1d(in_channels,
out_channels,
kernel_size,
padding=padding)
self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
self.dropout = nn.Dropout(p=0.5)
if activation == 'relu':
if activation == "relu":
self.activation = nn.ReLU()
elif activation == 'tanh':
elif activation == "tanh":
self.activation = nn.Tanh()
else:
self.activation = nn.Identity()
@ -55,16 +53,14 @@ class Postnet(nn.Module):
- input: (B, C_in, T)
- output: (B, C_in, T)
"""
def __init__(self, in_out_channels, num_convs=5):
super(Postnet, self).__init__()
super().__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh"))
for _ in range(1, num_convs - 1):
self.convolutions.append(
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(
ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh"))
self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
def forward(self, x):
o = x
@ -83,18 +79,15 @@ class Encoder(nn.Module):
- input: (B, C_in, T)
- output: (B, C_in, T)
"""
def __init__(self, in_out_channels=512):
super(Encoder, self).__init__()
super().__init__()
self.convolutions = nn.ModuleList()
for _ in range(3):
self.convolutions.append(
ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu'))
self.lstm = nn.LSTM(in_out_channels,
int(in_out_channels / 2),
num_layers=1,
batch_first=True,
bias=True,
bidirectional=True)
self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu"))
self.lstm = nn.LSTM(
in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True
)
self.rnn_state = None
def forward(self, x, input_lengths):
@ -102,9 +95,7 @@ class Encoder(nn.Module):
for layer in self.convolutions:
o = layer(o)
o = o.transpose(1, 2)
o = nn.utils.rnn.pack_padded_sequence(o,
input_lengths.cpu(),
batch_first=True)
o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True)
self.lstm.flatten_parameters()
o, _ = self.lstm(o)
o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)
@ -143,12 +134,27 @@ class Decoder(nn.Module):
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
"""
# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet):
super(Decoder, self).__init__()
# pylint: disable=attribute-defined-outside-init
def __init__(
self,
in_channels,
frame_channels,
r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
):
super().__init__()
self.frame_channels = frame_channels
self.r_init = r
self.r = r
@ -167,43 +173,36 @@ class Decoder(nn.Module):
# memory -> |Prenet| -> processed_memory
prenet_dim = self.frame_channels
self.prenet = Prenet(prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
self.prenet = Prenet(
prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False
)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels,
self.query_dim,
bias=True)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True)
self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K,
)
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels,
self.decoder_rnn_dim,
bias=True)
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True)
self.linear_projection = Linear(self.decoder_rnn_dim + in_channels,
self.frame_channels * self.r_init)
self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init)
self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"),
)
self.memory_truncated = None
def set_r(self, new_r):
@ -211,24 +210,18 @@ class Decoder(nn.Module):
def get_go_frame(self, inputs):
B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(
B, self.frame_channels * self.r)
memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r)
return memory
def _init_states(self, inputs, mask, keep_states=False):
B = inputs.size(0)
# T = inputs.size(1)
if not keep_states:
self.query = torch.zeros(1, device=inputs.device).repeat(
B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(
1, device=inputs.device).repeat(B, self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(
B, self.encoder_embedding_dim)
self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim)
self.inputs = inputs
self.processed_inputs = self.attention.preprocess_inputs(inputs)
self.mask = mask
@ -254,38 +247,36 @@ class Decoder(nn.Module):
def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.frame_channels * (self.r - 1):]
return memory[:, :, self.frame_channels * (self.r - 1):]
return memory[:, self.frame_channels * (self.r - 1) :]
return memory[:, :, self.frame_channels * (self.r - 1) :]
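# A minimal, self-contained sketch (not part of this commit) of what the
# reduction-factor slice above does: the decoder emits r frames per step and
# only the last frame of the group is kept as the next prenet input. Sizes
# below are assumed for illustration.
import torch

frame_channels, r, B = 80, 2, 4
memory = torch.randn(B, frame_channels * r)            # B x (r * frame_channels)
last_frame = memory[:, frame_channels * (r - 1):]      # keep only the final frame of the group
assert last_frame.shape == (B, frame_channels)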
def decode(self, memory):
'''
shapes:
- memory: B x r * self.frame_channels
'''
"""
shapes:
- memory: B x r * self.frame_channels
"""
# self.context: B x D_en
# query_input: B x D_en + (r * self.frame_channels)
query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn(
query_input, (self.query, self.attention_rnn_cell_state))
self.query = F.dropout(self.query, self.p_attention_dropout,
self.training)
query_input, (self.query, self.attention_rnn_cell_state)
)
self.query = F.dropout(self.query, self.p_attention_dropout, self.training)
self.attention_rnn_cell_state = F.dropout(
self.attention_rnn_cell_state, self.p_attention_dropout,
self.training)
self.attention_rnn_cell_state, self.p_attention_dropout, self.training
)
# B x D_en
self.context = self.attention(self.query, self.inputs,
self.processed_inputs, self.mask)
self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask)
# B x (D_en + D_attn_rnn)
decoder_rnn_input = torch.cat((self.query, self.context), -1)
# self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden,
self.p_decoder_dropout, self.training)
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell)
)
self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)
# B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1)
# B x (self.r * self.frame_channels)
decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.frame_channels))
@ -295,7 +286,7 @@ class Decoder(nn.Module):
else:
stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.frame_channels]
decoder_output = decoder_output[:, : self.r * self.frame_channels]
return decoder_output, self.attention.attention_weights, stop_token
def forward(self, inputs, memories, mask):
@ -329,8 +320,7 @@ class Decoder(nn.Module):
stop_tokens += [stop_token.squeeze(1)]
alignments += [attention_weights]
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)
return outputs, alignments, stop_tokens
def inference(self, inputs):
@ -369,8 +359,7 @@ class Decoder(nn.Module):
memory = self._update_memory(decoder_output)
t += 1
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)
return outputs, alignments, stop_tokens
@ -404,8 +393,7 @@ class Decoder(nn.Module):
self.memory_truncated = decoder_output
t += 1
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)
return outputs, alignments, stop_tokens
View File
@ -62,39 +62,27 @@ class AlignTTS(nn.Module):
# pylint: disable=dangerous-default-value
def __init__(
self,
num_chars,
out_channels,
hidden_channels=256,
hidden_channels_dp=256,
encoder_type='fftransformer',
encoder_params={
'hidden_channels_ffn': 1024,
'num_heads': 2,
'num_layers': 6,
'dropout_p': 0.1
},
decoder_type='fftransformer',
decoder_params={
'hidden_channels_ffn': 1024,
'num_heads': 2,
'num_layers': 6,
'dropout_p': 0.1
},
length_scale=1,
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels=256,
hidden_channels_dp=256,
encoder_type="fftransformer",
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
decoder_type="fftransformer",
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
length_scale=1,
num_speakers=0,
external_c=False,
c_in_channels=0,
):
super().__init__()
self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.pos_encoder = PositionalEncoding(hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels_dp)
self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)
@ -111,13 +99,13 @@ class AlignTTS(nn.Module):
@staticmethod
def compute_log_probs(mu, log_sigma, y):
# pylint: disable=protected-access, c-extension-no-member
y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
dim=-1) # B, L, T
exponential = -0.5 * torch.mean(
torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
) # B, L, T
logp = exponential - 0.5 * log_sigma.mean(dim=-1)
return logp
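# Hedged reference for the broadcast arithmetic above (illustrative, not
# repository code): the same score written with explicit loops, assuming
# mse_loss(..., 0) is the element-wise squared error and the layouts are
# mu, log_sigma: [B, C, T2] and y: [B, C, T1] as in the shape comments above.
import torch

def compute_log_probs_reference(mu, log_sigma, y):
    B, C, T2 = mu.shape
    T1 = y.shape[-1]
    logp = torch.zeros(B, T2, T1)
    for t2 in range(T2):
        for t1 in range(T1):
            sq_err = (y[:, :, t1] - mu[:, :, t2]) ** 2           # [B, C]
            var = log_sigma[:, :, t2].exp() ** 2                 # [B, C]
            logp[:, t2, t1] = (
                -0.5 * (sq_err / var).mean(-1) - 0.5 * log_sigma[:, :, t2].mean(-1)
            )
    return logp
# logp = compute_log_probs_reference(mu, log_sigma, y)   # [B, T2, T1], same values as above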
@ -151,9 +139,7 @@ class AlignTTS(nn.Module):
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
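# Illustrative sketch (not from the repository) of the expansion the matmul
# above performs once durations are turned into a hard alignment: each encoder
# frame is repeated `duration` times along the decoder axis.
import torch

durations = torch.tensor([[2, 1, 3]])                        # B=1, T_en=3
en = torch.arange(3, dtype=torch.float32).view(1, 1, 3)      # B x C x T_en, C=1
T_de = int(durations.sum())
attn = torch.zeros(1, 3, T_de)                               # B x T_en x T_de
start = 0
for i, d in enumerate(durations[0].tolist()):
    attn[0, i, start:start + d] = 1.0
    start += d
o_en_ex = torch.matmul(attn.transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
assert o_en_ex.squeeze().tolist() == [0.0, 0.0, 1.0, 2.0, 2.0, 2.0]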
def format_durations(self, o_dr_log, x_mask):
@ -170,12 +156,12 @@ class AlignTTS(nn.Module):
def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, 'proj_g'):
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g
def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, 'emb_g'):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
@ -187,8 +173,7 @@ class AlignTTS(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
@ -201,12 +186,11 @@ class AlignTTS(nn.Module):
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, 'pos_encoder'):
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:
@ -218,10 +202,8 @@ class AlignTTS(nn.Module):
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
# MAS potentials and alignment
mu, log_sigma = self.mdn_block(o_en)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
y_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
return dr_mas, mu, log_sigma, logp
def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument
@ -237,56 +219,31 @@ class AlignTTS(nn.Module):
if phase == 0:
# train encoder and MDN
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en.detach(),
o_en_dp.detach(),
dr_mas.detach(),
x_mask,
y_lengths,
g=g)
o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
elif phase == 2:
# train the whole except duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
elif phase == 3:
# train duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
else:
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp
@ -307,17 +264,14 @@ class AlignTTS(nn.Module):
# duration predictor pass
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
View File
@ -34,28 +34,31 @@ class GlowTTS(nn.Module):
encoder_params (dict): encoder module parameters.
external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
"""
def __init__(self,
num_chars,
hidden_channels_enc,
hidden_channels_dec,
use_encoder_prenet,
hidden_channels_dp,
out_channels,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
dropout_p_dp=0.1,
dropout_p_dec=0.05,
num_speakers=0,
c_in_channels=0,
num_splits=4,
num_squeeze=1,
sigmoid_scale=False,
mean_only=False,
encoder_type="transformer",
encoder_params=None,
external_speaker_embedding_dim=None):
def __init__(
self,
num_chars,
hidden_channels_enc,
hidden_channels_dec,
use_encoder_prenet,
hidden_channels_dp,
out_channels,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
dropout_p_dp=0.1,
dropout_p_dec=0.05,
num_speakers=0,
c_in_channels=0,
num_splits=4,
num_squeeze=1,
sigmoid_scale=False,
mean_only=False,
encoder_type="transformer",
encoder_params=None,
external_speaker_embedding_dim=None,
):
super().__init__()
self.num_chars = num_chars
@ -78,7 +81,7 @@ class GlowTTS(nn.Module):
# model constants.
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech.
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
self.external_speaker_embedding_dim = external_speaker_embedding_dim
# if is a multispeaker and c_in_channels is 0, set to 256
@ -88,28 +91,32 @@ class GlowTTS(nn.Module):
elif self.external_speaker_embedding_dim:
self.c_in_channels = self.external_speaker_embedding_dim
self.encoder = Encoder(num_chars,
out_channels=out_channels,
hidden_channels=hidden_channels_enc,
hidden_channels_dp=hidden_channels_dp,
encoder_type=encoder_type,
encoder_params=encoder_params,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
dropout_p_dp=dropout_p_dp,
c_in_channels=self.c_in_channels)
self.encoder = Encoder(
num_chars,
out_channels=out_channels,
hidden_channels=hidden_channels_enc,
hidden_channels_dp=hidden_channels_dp,
encoder_type=encoder_type,
encoder_params=encoder_params,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
dropout_p_dp=dropout_p_dp,
c_in_channels=self.c_in_channels,
)
self.decoder = Decoder(out_channels,
hidden_channels_dec,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
num_squeeze=num_squeeze,
sigmoid_scale=sigmoid_scale,
c_in_channels=self.c_in_channels)
self.decoder = Decoder(
out_channels,
hidden_channels_dec,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
num_squeeze=num_squeeze,
sigmoid_scale=sigmoid_scale,
c_in_channels=self.c_in_channels,
)
if num_speakers > 1 and not external_speaker_embedding_dim:
# speaker embedding layer
@ -119,12 +126,12 @@ class GlowTTS(nn.Module):
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment
y_mean = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
# compute total duration with adjustment
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
return y_mean, y_log_scale, o_attn_dur
@ -144,37 +151,27 @@ class GlowTTS(nn.Module):
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1]
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop residual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(
y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp,
attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur
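# Hedged sanity check (not part of this commit): the logp1..logp4 terms above
# add up to a diagonal-Gaussian log-likelihood evaluated for every pair of
# text frame t and mel frame t', expanded so the pairwise scores reduce to two
# matmuls. Sizes are assumed for illustration.
import math
import torch

b, d, t, t_prime = 2, 8, 5, 7
o_mean = torch.randn(b, d, t)
o_log_scale = torch.randn(b, d, t) * 0.1
z = torch.randn(b, d, t_prime)

o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4                            # [b, t, t']

# reference: log N(z[:, :, t'] ; o_mean[:, :, t], exp(o_log_scale)[:, :, t]) summed over d
normal = torch.distributions.Normal(o_mean.unsqueeze(-1), o_log_scale.exp().unsqueeze(-1))
reference = normal.log_prob(z.unsqueeze(2)).sum(1)              # [b, t, t']
assert torch.allclose(logp, reference, atol=1e-4)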
@ -187,26 +184,20 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]
# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask
attn = generate_path(w_ceil.squeeze(1),
attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
self.noise_scale) * y_mask
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask
# decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)
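# Illustrative only (toy values, not from the repository): how length_scale
# stretches the predicted durations at inference, producing more output frames
# and therefore slower speech.
import torch

o_dur_log = torch.tensor([[[0.0, 0.7, 1.1]]])                  # B x 1 x T_en log-durations
x_mask = torch.ones_like(o_dur_log)
for length_scale in (1.0, 1.5):
    w = (torch.exp(o_dur_log) - 1) * x_mask * length_scale
    w_ceil = torch.ceil(w)
    y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
    print(length_scale, w_ceil.squeeze().tolist(), int(y_lengths))
# length_scale 1.0 -> durations [0., 2., 3.] and 5 output frames;
# length_scale 1.5 -> durations [0., 2., 4.] and 6 output frames.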
@ -224,9 +215,11 @@ class GlowTTS(nn.Module):
def store_inverse(self):
self.decoder.store_inverse()
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
self.store_inverse()
View File
@ -34,42 +34,37 @@ class SpeedySpeech(nn.Module):
external_c (bool, optional): enable external speaker embeddings. Defaults to False.
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
"""
# pylint: disable=dangerous-default-value
def __init__(
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17,
},
num_speakers=0,
external_c=False,
c_in_channels=0,
):
super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels,
decoder_type, decoder_params)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
if num_speakers > 1 and not external_c:
@ -97,9 +92,7 @@ class SpeedySpeech(nn.Module):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
@ -116,12 +109,12 @@ class SpeedySpeech(nn.Module):
def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, 'proj_g'):
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g
def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, 'emb_g'):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
@ -133,8 +126,7 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
@ -147,12 +139,11 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, 'pos_encoder'):
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:
@ -187,7 +178,7 @@ class SpeedySpeech(nn.Module):
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode='constant', value=0)
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
@ -196,9 +187,11 @@ class SpeedySpeech(nn.Module):
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn
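# Small illustration of the minimum-length padding used in the inference path
# above (the base value of inference_padding is assumed here, not taken from
# the repository): inputs shorter than 13 tokens get extra zeros so, per the
# comment above, the last word is not dropped.
import torch

inference_padding = 5                          # assumed base padding
x = torch.randint(0, 100, (1, 9))              # a 9-token input
if x.shape[1] < 13:
    inference_padding += 13 - x.shape[1]       # 5 + 4 = 9 trailing zeros
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
assert x.shape == (1, 18)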
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training
View File
@ -47,45 +47,67 @@ class Tacotron(TacotronAbstract):
memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
output frames to the prenet.
"""
def __init__(self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=256,
decoder_in_features=256,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
memory_size=5,
gst_use_speaker_embedding=False):
super(Tacotron,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, encoder_in_features, decoder_in_features,
speaker_embedding_dim, gst, gst_embedding_dim,
gst_num_heads, gst_style_tokens, gst_use_speaker_embedding)
def __init__(
self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=256,
decoder_in_features=256,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
memory_size=5,
gst_use_speaker_embedding=False,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)
# speaker embedding layers
if self.num_speakers > 1:
@ -96,7 +118,7 @@ class Tacotron(TacotronAbstract):
# speaker and gst embeddings is concat in decoder input
if self.num_speakers > 1:
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
# embedding layer
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)
@ -104,32 +126,59 @@ class Tacotron(TacotronAbstract):
# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r,
memory_size, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn,
attn_K, separate_stopnet)
self.decoder = Decoder(
self.decoder_in_features,
decoder_output_dim,
r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
postnet_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
# global style token layers
if self.gst:
self.gst_layer = GST(num_mel=80,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None)
self.gst_layer = GST(
num_mel=80,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(
self.decoder_in_features, decoder_output_dim, ddc_r, memory_size,
attn_type, attn_win, attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask, location_attn,
attn_K, separate_stopnet)
self.decoder_in_features,
decoder_output_dim,
ddc_r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
"""
@ -151,9 +200,9 @@ class Tacotron(TacotronAbstract):
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
mel_specs,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
# speaker embedding
if self.num_speakers > 1:
if not self.embeddings_per_sample:
@ -166,8 +215,7 @@ class Tacotron(TacotronAbstract):
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
@ -182,10 +230,26 @@ class Tacotron(TacotronAbstract):
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
@ -194,9 +258,9 @@ class Tacotron(TacotronAbstract):
encoder_outputs = self.encoder(inputs)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
style_mel,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
@ -205,8 +269,7 @@ class Tacotron(TacotronAbstract):
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)
View File
@ -44,44 +44,66 @@ class Tacotron2(TacotronAbstract):
gst_style_tokens (int, optional): number of GST tokens. Defaults to 10.
gst_use_speaker_embedding (bool, optional): enable/disable inputting speaker embedding to GST. Defaults to False.
"""
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False):
super(Tacotron2,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, encoder_in_features, decoder_in_features,
speaker_embedding_dim, gst, gst_embedding_dim,
gst_num_heads, gst_style_tokens, gst_use_speaker_embedding)
def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)
# speaker embedding layer
if self.num_speakers > 1:
@ -92,36 +114,63 @@ class Tacotron2(TacotronAbstract):
# speaker and gst embeddings is concat in decoder input
if self.num_speakers > 1:
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
# embedding layer
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)
# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet)
self.decoder = Decoder(
self.decoder_in_features,
self.decoder_output_dim,
r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
self.postnet = Postnet(self.postnet_output_dim)
# global style token layers
if self.gst:
self.gst_layer = GST(num_mel=80,
num_heads=self.gst_num_heads,
num_style_tokens=self.gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None)
self.gst_layer = GST(
num_mel=80,
num_heads=self.gst_num_heads,
num_style_tokens=self.gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(
self.decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet)
self.decoder_in_features,
self.decoder_output_dim,
ddc_r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):
@ -148,9 +197,9 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
mel_specs,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim
@ -163,8 +212,7 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
@ -175,14 +223,29 @@ class Tacotron2(TacotronAbstract):
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
@torch.no_grad()
@ -192,20 +255,18 @@ class Tacotron2(TacotronAbstract):
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
style_mel,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = decoder_outputs + postnet_outputs
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
return decoder_outputs, postnet_outputs, alignments, stop_tokens
def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
@ -217,19 +278,17 @@ class Tacotron2(TacotronAbstract):
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
style_mel,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
encoder_outputs)
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
mel_outputs, mel_outputs_postnet, alignments)
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments)
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
View File
@ -8,34 +8,36 @@ from TTS.tts.utils.generic_utils import sequence_mask
class TacotronAbstract(ABC, nn.Module):
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False):
def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False,
):
""" Abstract Tacotron class """
super().__init__()
self.num_chars = num_chars
@ -82,7 +84,7 @@ class TacotronAbstract(ABC, nn.Module):
# global style token
if self.gst:
self.decoder_in_features += gst_embedding_dim # add gst embedding dim
self.decoder_in_features += gst_embedding_dim # add gst embedding dim
self.gst_layer = None
# model states
@ -121,10 +123,12 @@ class TacotronAbstract(ABC, nn.Module):
def inference(self):
pass
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
self.decoder.set_r(state['r'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
self.decoder.set_r(state["r"])
if eval:
self.eval()
assert not self.training
@ -149,25 +153,24 @@ class TacotronAbstract(ABC, nn.Module):
def _backward_pass(self, mel_specs, encoder_outputs, mask):
""" Run backwards decoder """
decoder_outputs_b, alignments_b, _ = self.decoder_backward(
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask)
encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
)
decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
return decoder_outputs_b, alignments_b
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
input_mask):
def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
""" Double Decoder Consistency """
T = mel_specs.shape[1]
if T % self.coarse_decoder.r > 0:
padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
mel_specs = torch.nn.functional.pad(mel_specs,
(0, 0, 0, padding_size, 0, 0))
mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
encoder_outputs.detach(), mel_specs, input_mask)
encoder_outputs.detach(), mel_specs, input_mask
)
# scale_factor = self.decoder.r_init / self.decoder.r
alignments_backward = torch.nn.functional.interpolate(
alignments_backward.transpose(1, 2),
size=alignments.shape[1],
mode='nearest').transpose(1, 2)
alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
).transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
return decoder_outputs_backward, alignments_backward
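# A sketch with assumed sizes (not repository code) of the padding arithmetic
# above, which makes the mel length divisible by the coarse decoder's
# reduction factor before the DDC pass.
import torch

T, coarse_r = 250, 6
mel_specs = torch.randn(2, T, 80)                    # B x T x n_mel
if T % coarse_r > 0:
    padding_size = coarse_r - (T % coarse_r)         # 6 - (250 % 6) = 2
    mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
assert mel_specs.shape[1] % coarse_r == 0            # 252 frames, divisible by r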
@ -179,20 +182,17 @@ class TacotronAbstract(ABC, nn.Module):
def compute_speaker_embedding(self, speaker_ids):
""" Compute speaker embedding vectors """
if hasattr(self, "speaker_embedding") and speaker_ids is None:
raise RuntimeError(
" [!] Model has speaker embedding layer but speaker_id is not provided"
)
raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
if hasattr(self, "speaker_embedding") and speaker_ids is not None:
self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
self.speaker_embeddings_projected = self.speaker_project_mel(
self.speaker_embeddings).squeeze(1)
self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1)
def compute_gst(self, inputs, style_input, speaker_embedding=None):
""" Compute global style token """
device = inputs.device
if isinstance(style_input, dict):
query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device)
if speaker_embedding is not None:
query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)
@ -205,20 +205,18 @@ class TacotronAbstract(ABC, nn.Module):
elif style_input is None:
gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
else:
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
inputs = self._concat_speaker_embedding(inputs, gst_outputs)
return inputs
@staticmethod
def _add_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
outputs.size(0), outputs.size(1), -1)
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
outputs = outputs + speaker_embeddings_
return outputs
@staticmethod
def _concat_speaker_embedding(outputs, speaker_embeddings):
speaker_embeddings_ = speaker_embeddings.expand(
outputs.size(0), outputs.size(1), -1)
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
return outputs
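# Minimal shape illustration (assumed sizes, not part of this commit) of the
# expand-and-concatenate pattern above, which attaches the speaker embedding
# to every encoder frame.
import torch

encoder_outputs = torch.randn(2, 50, 256)        # B x T_in x C_en
speaker_embeddings = torch.randn(2, 1, 64)       # B x 1 x C_spk
speaker_embeddings_ = speaker_embeddings.expand(2, 50, -1)
outputs = torch.cat([encoder_outputs, speaker_embeddings_], dim=-1)
assert outputs.shape == (2, 50, 320)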
View File
@ -1,16 +1,18 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class Linear(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super(Linear, self).__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.activation = keras.layers.ReLU()
def call(self, x):
@ -23,9 +25,11 @@ class Linear(keras.layers.Layer):
class LinearBN(keras.layers.Layer):
def __init__(self, units, use_bias, **kwargs):
super(LinearBN, self).__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
super().__init__(**kwargs)
self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
self.batch_normalization = keras.layers.BatchNormalization(
axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.activation = keras.layers.ReLU()
def call(self, x, training=None):
@ -39,22 +43,21 @@ class LinearBN(keras.layers.Layer):
class Prenet(keras.layers.Layer):
def __init__(self,
prenet_type,
prenet_dropout,
units,
bias,
**kwargs):
super(Prenet, self).__init__(**kwargs)
def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs):
super().__init__(**kwargs)
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
self.linear_layers = []
if prenet_type == "bn":
self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
self.linear_layers += [
LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
elif prenet_type == "original":
self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
self.linear_layers += [
Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
]
else:
raise RuntimeError(' [!] Unknown prenet type.')
raise RuntimeError(" [!] Unknown prenet type.")
if prenet_dropout:
self.dropout = keras.layers.Dropout(rate=0.5)
@ -80,11 +83,22 @@ def _sigmoid_norm(score):
class Attention(keras.layers.Layer):
"""TODO: implement forward_attention
TODO: location sensitive attention
TODO: implement attention windowing """
def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
use_trans_agent, use_forward_attn_mask, **kwargs):
super(Attention, self).__init__(**kwargs)
TODO: implement attention windowing"""
def __init__(
self,
attn_dim,
use_loc_attn,
loc_attn_n_filters,
loc_attn_kernel_size,
use_windowing,
norm,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
**kwargs,
):
super().__init__(**kwargs)
self.use_loc_attn = use_loc_attn
self.loc_attn_n_filters = loc_attn_n_filters
self.loc_attn_kernel_size = loc_attn_kernel_size
@ -93,20 +107,23 @@ class Attention(keras.layers.Layer):
self.use_forward_attn = use_forward_attn
self.use_trans_agent = use_trans_agent
self.use_forward_attn_mask = use_forward_attn_mask
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer")
self.inputs_layer = tf.keras.layers.Dense(
attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer"
)
self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer")
if use_loc_attn:
self.location_conv1d = keras.layers.Conv1D(
filters=loc_attn_n_filters,
kernel_size=loc_attn_kernel_size,
padding='same',
padding="same",
use_bias=False,
name='location_layer/location_conv1d')
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
if norm == 'softmax':
name="location_layer/location_conv1d",
)
self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense")
if norm == "softmax":
self.norm_func = tf.nn.softmax
elif norm == 'sigmoid':
elif norm == "sigmoid":
self.norm_func = _sigmoid_norm
else:
raise ValueError("Unknown value for attention norm type")
@ -118,30 +135,25 @@ class Attention(keras.layers.Layer):
attention_old = tf.zeros([batch_size, value_length])
states = [attention_cum, attention_old]
if self.use_forward_attn:
alpha = tf.concat([
tf.ones([batch_size, 1]),
tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
], 1)
alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1)
states.append(alpha)
return tuple(states)
def process_values(self, values):
""" cache values for decoder iterations """
#pylint: disable=attribute-defined-outside-init
# pylint: disable=attribute-defined-outside-init
self.processed_values = self.inputs_layer(values)
self.values = values
def get_loc_attn(self, query, states):
""" compute location attention, query layer and
"""compute location attention, query layer and
unnorm. attention weights"""
attention_cum, attention_old = states[:2]
attn_cat = tf.stack([attention_old, attention_cum], axis=2)
processed_query = self.query_layer(tf.expand_dims(query, 1))
processed_attn = self.location_dense(self.location_conv1d(attn_cat))
score = self.v(
tf.nn.tanh(self.processed_values + processed_query +
processed_attn))
score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn))
score = tf.squeeze(score, axis=2)
return score, processed_query
@ -152,14 +164,14 @@ class Attention(keras.layers.Layer):
score = tf.squeeze(score, axis=2)
return score, processed_query
def apply_score_masking(self, score, mask): #pylint: disable=no-self-use
def apply_score_masking(self, score, mask): # pylint: disable=no-self-use
""" ignore sequence paddings """
padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
# Bias so padding positions do not contribute to attention distribution.
score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32)
return score
def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use
def apply_forward_attention(self, alignment, alpha): # pylint: disable=no-self-use
# forward attention
fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
# compute transition potentials
@ -206,7 +218,9 @@ class Attention(keras.layers.Layer):
states = self.update_states(states, scores_norm, attn_weights, new_alpha)
# context_vector shape after sum == (batch_size, hidden_size)
context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
context_vector = tf.matmul(
tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False
)
context_vector = tf.squeeze(context_vector, axis=1)
return context_vector, attn_weights, states
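# Shape check with assumed sizes (not repository code) for the context-vector
# matmul above: attention weights of shape (batch, value_length) reduce the
# cached values to a single context vector per batch item.
import tensorflow as tf

batch, value_length, hidden = 2, 11, 128
attn_weights = tf.nn.softmax(tf.random.normal([batch, value_length]), axis=-1)
values = tf.random.normal([batch, value_length, hidden])
context = tf.matmul(tf.expand_dims(attn_weights, axis=2), values,
                    transpose_a=True, transpose_b=False)   # (batch, 1, hidden)
context = tf.squeeze(context, axis=1)                      # (batch, hidden)
assert context.shape == (batch, hidden)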
@ -230,8 +244,7 @@ class Attention(keras.layers.Layer):
# location_attention_filters=32,
# location_attention_kernel_size=31):
# super(LocationSensitiveAttention,
# self).__init__(units=units,
# super( self).__init__(units=units,
# memory=memory,
# memory_sequence_length=memory_sequence_length,
# normalize=normalize,

View File

@ -5,15 +5,17 @@ from TTS.tts.tf.layers.tacotron.common_layers import Prenet, Attention
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(keras.layers.Layer):
def __init__(self, filters, kernel_size, activation, **kwargs):
super(ConvBNBlock, self).__init__(**kwargs)
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
self.activation = keras.layers.Activation(activation, name='activation')
super().__init__(**kwargs)
self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d")
self.batch_normalization = keras.layers.BatchNormalization(
axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization"
)
self.dropout = keras.layers.Dropout(rate=0.5, name="dropout")
self.activation = keras.layers.Activation(activation, name="activation")
def call(self, x, training=None):
o = self.convolution1d(x)
@ -25,12 +27,12 @@ class ConvBNBlock(keras.layers.Layer):
class Postnet(keras.layers.Layer):
def __init__(self, output_filters, num_convs, **kwargs):
super(Postnet, self).__init__(**kwargs)
super().__init__(**kwargs)
self.convolutions = []
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0"))
for idx in range(1, num_convs - 1):
self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}"))
self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}"))
def call(self, x, training=None):
o = x
@ -41,11 +43,13 @@ class Postnet(keras.layers.Layer):
class Encoder(keras.layers.Layer):
def __init__(self, output_input_dim, **kwargs):
super(Encoder, self).__init__(**kwargs)
super().__init__(**kwargs)
self.convolutions = []
for idx in range(3):
self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}"))
self.lstm = keras.layers.Bidirectional(
keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm"
)
def call(self, x, training=None):
o = x
@ -56,11 +60,27 @@ class Encoder(keras.layers.Layer):
class Decoder(keras.layers.Layer):
#pylint: disable=unused-argument
def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
super(Decoder, self).__init__(**kwargs)
# pylint: disable=unused-argument
def __init__(
self,
frame_dim,
r,
attn_type,
use_attn_win,
attn_norm,
prenet_type,
prenet_dropout,
use_forward_attn,
use_trans_agent,
use_forward_attn_mask,
use_location_attn,
attn_K,
separate_stopnet,
speaker_emb_dim,
enable_tflite,
**kwargs,
):
super().__init__(**kwargs)
self.frame_dim = frame_dim
self.r_init = tf.constant(r, dtype=tf.int32)
self.r = tf.constant(r, dtype=tf.int32)
@ -80,30 +100,31 @@ class Decoder(keras.layers.Layer):
self.p_attention_dropout = 0.1
self.p_decoder_dropout = 0.1
self.prenet = Prenet(prenet_type,
prenet_dropout,
[self.prenet_dim, self.prenet_dim],
bias=False,
name='prenet')
self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet")
self.attention_rnn = keras.layers.LSTMCell(
self.query_dim,
use_bias=True,
name="attention_rnn",
)
self.attention_rnn_dropout = keras.layers.Dropout(0.5)
# TODO: implement other attn options
self.attention = Attention(attn_dim=self.attn_dim,
use_loc_attn=True,
loc_attn_n_filters=32,
loc_attn_kernel_size=31,
use_windowing=False,
norm=attn_norm,
use_forward_attn=use_forward_attn,
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name='attention')
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
self.attention = Attention(
attn_dim=self.attn_dim,
use_loc_attn=True,
loc_attn_n_filters=32,
loc_attn_kernel_size=31,
use_windowing=False,
norm=attn_norm,
use_forward_attn=use_forward_attn,
use_trans_agent=use_trans_agent,
use_forward_attn_mask=use_forward_attn_mask,
name="attention",
)
self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn")
self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')
self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer")
self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer")
def set_max_decoder_steps(self, new_max_steps):
self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)
@ -120,25 +141,31 @@ class Decoder(keras.layers.Layer):
attention_states = self.attention.init_states(batch_size, memory_length)
return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states
def step(self, prenet_next, states,
memory_seq_length=None, training=None):
def step(self, prenet_next, states, memory_seq_length=None, training=None):
_, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
attention_rnn_input = tf.concat([prenet_next, context_next], -1)
attention_rnn_output, attention_rnn_state = \
self.attention_rnn(attention_rnn_input,
attention_rnn_state, training=training)
attention_rnn_output, attention_rnn_state = self.attention_rnn(
attention_rnn_input, attention_rnn_state, training=training
)
attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
decoder_rnn_output, decoder_rnn_state = \
self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
decoder_rnn_output, decoder_rnn_state = self.decoder_rnn(
decoder_rnn_input, decoder_rnn_state, training=training
)
decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
output_frame = self.linear_projection(linear_projection_input, training=training)
stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
stopnet_output = self.stopnet(stopnet_input, training=training)
output_frame = output_frame[:, :self.r * self.frame_dim]
states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
output_frame = output_frame[:, : self.r * self.frame_dim]
states = (
output_frame[:, self.frame_dim * (self.r - 1) :],
context,
attention_rnn_state,
decoder_rnn_state,
attention_states,
)
return output_frame, stopnet_output, states, attention
def decode(self, memory, states, frames, memory_seq_length=None):
@ -157,21 +184,20 @@ class Decoder(keras.layers.Layer):
def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
prenet_next = prenet_output[:, step]
output, stop_token, states, attention = self.step(prenet_next,
states,
memory_seq_length)
output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
stop_tokens = stop_tokens.write(step, stop_token)
return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
_, memory, _, states, outputs, stop_tokens, attentions = \
tf.while_loop(lambda *arg: True,
_body,
loop_vars=(step_count, memory, prenet_output,
states, outputs, stop_tokens, attentions),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=num_iter)
_, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop(
lambda *arg: True,
_body,
loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=num_iter,
)
outputs = outputs.stack()
attentions = attentions.stack()
@ -200,10 +226,7 @@ class Decoder(keras.layers.Layer):
def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, attention = self.step(prenet_next,
states,
None,
training=False)
output, stop_token, states, attention = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
outputs = outputs.write(step, output)
attentions = attentions.write(step, attention)
@ -213,14 +236,14 @@ class Decoder(keras.layers.Layer):
return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag
cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
_, memory, states, outputs, stop_tokens, attentions, stop_flag = \
tf.while_loop(cond,
_body,
loop_vars=(step_count, memory, states, outputs,
stop_tokens, attentions, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps)
_, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
attentions = attentions.stack()
@ -238,12 +261,13 @@ class Decoder(keras.layers.Layer):
batch_size is 1"""
# init states
# dynamic_shape is not supported in TFLite
outputs = tf.TensorArray(dtype=tf.float32,
size=self.max_decoder_steps,
element_shape=tf.TensorShape(
[self.output_dim]),
clear_after_read=False,
dynamic_size=False)
outputs = tf.TensorArray(
dtype=tf.float32,
size=self.max_decoder_steps,
element_shape=tf.TensorShape([self.output_dim]),
clear_after_read=False,
dynamic_size=False,
)
# stop_flags = tf.TensorArray(dtype=tf.bool,
# size=self.max_decoder_steps,
# element_shape=tf.TensorShape(
@ -263,10 +287,7 @@ class Decoder(keras.layers.Layer):
def _body(step, memory, states, outputs, stop_flag):
frame_next = states[0]
prenet_next = self.prenet(frame_next, training=False)
output, stop_token, states, _ = self.step(prenet_next,
states,
None,
training=False)
output, stop_token, states, _ = self.step(prenet_next, states, None, training=False)
stop_token = tf.math.sigmoid(stop_token)
stop_flag = tf.greater(stop_token, self.stop_thresh)
stop_flag = tf.reduce_all(stop_flag)
@ -276,24 +297,22 @@ class Decoder(keras.layers.Layer):
return step + 1, memory, states, outputs, stop_flag
cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
step_count, memory, states, outputs, stop_flag = \
tf.while_loop(cond,
_body,
loop_vars=(step_count, memory, states, outputs,
stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps)
step_count, memory, states, outputs, stop_flag = tf.while_loop(
cond,
_body,
loop_vars=(step_count, memory, states, outputs, stop_flag),
parallel_iterations=32,
swap_memory=True,
maximum_iterations=self.max_decoder_steps,
)
outputs = outputs.stack()
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
outputs = tf.expand_dims(outputs, axis=[0])
outputs = tf.transpose(outputs, [1, 0, 2])
outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
return outputs, stop_tokens, attentions
def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
if training:
return self.decode(memory, states, frames, memory_seq_length)

View File

@ -5,28 +5,30 @@ from TTS.tts.tf.layers.tacotron.tacotron2 import Encoder, Decoder, Postnet
from TTS.tts.tf.utils.tf_utils import shape_list
#pylint: disable=too-many-ancestors, abstract-method
# pylint: disable=too-many-ancestors, abstract-method
class Tacotron2(keras.models.Model):
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
attn_K=4,
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True,
bidirectional_decoder=False,
enable_tflite=False):
super(Tacotron2, self).__init__()
def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
attn_K=4,
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True,
bidirectional_decoder=False,
enable_tflite=False,
):
super().__init__()
self.r = r
self.decoder_output_dim = decoder_output_dim
self.postnet_output_dim = postnet_output_dim
@ -35,26 +37,28 @@ class Tacotron2(keras.models.Model):
self.speaker_embed_dim = 256
self.enable_tflite = enable_tflite
self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
self.encoder = Encoder(512, name='encoder')
self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding")
self.encoder = Encoder(512, name="encoder")
# TODO: most of the decoder args have no use at the moment
self.decoder = Decoder(decoder_output_dim,
r,
attn_type=attn_type,
use_attn_win=attn_win,
attn_norm=attn_norm,
prenet_type=prenet_type,
prenet_dropout=prenet_dropout,
use_forward_attn=forward_attn,
use_trans_agent=trans_agent,
use_forward_attn_mask=forward_attn_mask,
use_location_attn=location_attn,
attn_K=attn_K,
separate_stopnet=separate_stopnet,
speaker_emb_dim=self.speaker_embed_dim,
name='decoder',
enable_tflite=enable_tflite)
self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
self.decoder = Decoder(
decoder_output_dim,
r,
attn_type=attn_type,
use_attn_win=attn_win,
attn_norm=attn_norm,
prenet_type=prenet_type,
prenet_dropout=prenet_dropout,
use_forward_attn=forward_attn,
use_trans_agent=trans_agent,
use_forward_attn_mask=forward_attn_mask,
use_location_attn=location_attn,
attn_K=attn_K,
separate_stopnet=separate_stopnet,
speaker_emb_dim=self.speaker_embed_dim,
name="decoder",
enable_tflite=enable_tflite,
)
self.postnet = Postnet(postnet_output_dim, 5, name="postnet")
@tf.function(experimental_relax_shapes=True)
def call(self, characters, text_lengths=None, frames=None, training=None):
@ -62,14 +66,16 @@ class Tacotron2(keras.models.Model):
return self.training(characters, text_lengths, frames)
if not training:
return self.inference(characters)
raise RuntimeError(' [!] Set model training mode True or False')
raise RuntimeError(" [!] Set model training mode True or False")
def training(self, characters, text_lengths, frames):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=True)
encoder_output = self.encoder(embedding_vectors, training=True)
decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True)
decoder_frames, stop_tokens, attentions = self.decoder(
encoder_output, decoder_states, frames, text_lengths, training=True
)
postnet_frames = self.postnet(decoder_frames, training=True)
output_frames = decoder_frames + postnet_frames
return decoder_frames, output_frames, attentions, stop_tokens
@ -89,7 +95,8 @@ class Tacotron2(keras.models.Model):
experimental_relax_shapes=True,
input_signature=[
tf.TensorSpec([1, None], dtype=tf.int32),
],)
],
)
def inference_tflite(self, characters):
B, T = shape_list(characters)
embedding_vectors = self.embedding(characters, training=False)
@ -101,7 +108,9 @@ class Tacotron2(keras.models.Model):
print(output_frames.shape)
return decoder_frames, output_frames, attentions, stop_tokens
def build_inference(self, ):
def build_inference(
self,
):
# TODO: issue https://github.com/PyCQA/pylint/issues/3613
input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) #pylint: disable=unexpected-keyword-arg
input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) # pylint: disable=unexpected-keyword-arg
self(input_ids)
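As a rough usage sketch (not part of the diff): the TF model can be built directly and traced once so the Keras variables are created. The hyper-parameter values below are placeholders for illustration, and the import path follows the module naming used by the setup_model helper further down; in practice the model is built from a training config.
from TTS.tts.tf.models.tacotron2 import Tacotron2  # path assumed from setup_model below

model = Tacotron2(
    num_chars=60,            # placeholder vocabulary size
    num_speakers=1,
    r=2,                     # reduction factor (frames per decoder step)
    postnet_output_dim=80,
    decoder_output_dim=80,
    enable_tflite=False,
)
model.build_inference()      # runs a dummy batch through the model so the weights exist
print(len(model.weights), "variables created")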

View File

@ -2,8 +2,9 @@ import numpy as np
import tensorflow as tf
# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
def tf_create_dummy_inputs():
""" Create dummy inputs for TF Tacotron2 model """
@ -13,11 +14,11 @@ def tf_create_dummy_inputs():
pad = 1
n_chars = 24
input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size])
input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size])
input_lengths[-1] = max_input_length
input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size])
mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size])
mel_lengths[-1] = max_mel_length
mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
return input_ids, input_lengths, mel_outputs, mel_lengths
@ -31,14 +32,14 @@ def compare_torch_tf(torch_tensor, tf_tensor):
def convert_tf_name(tf_name):
""" Convert certain patterns in TF layer names to Torch patterns """
tf_name_tmp = tf_name
tf_name_tmp = tf_name_tmp.replace(':0', '')
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
tf_name_tmp = tf_name_tmp.replace('/', '.')
tf_name_tmp = tf_name_tmp.replace(":0", "")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0")
tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1")
tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh")
tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight")
tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight")
tf_name_tmp = tf_name_tmp.replace("/beta", "/bias")
tf_name_tmp = tf_name_tmp.replace("/", ".")
return tf_name_tmp
@ -47,33 +48,35 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
print(" > Passing weights from Torch to TF ...")
for tf_var in tf_vars:
torch_var_name = var_map_dict[tf_var.name]
print(f' | > {tf_var.name} <-- {torch_var_name}')
print(f" | > {tf_var.name} <-- {torch_var_name}")
# if tuple, it is a bias variable
if not isinstance(torch_var_name, tuple):
torch_layer_name = '.'.join(torch_var_name.split('.')[-2:])
torch_layer_name = ".".join(torch_var_name.split(".")[-2:])
torch_weight = state_dict[torch_var_name]
if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name:
if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name:
# out_dim, in_dim, filter -> filter, in_dim, out_dim
numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name:
elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
# if the variable is for a bidirectional lstm and it is a bias vector, there
# need to be two matching pre-defined torch bias vectors
elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name:
elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif 'rnn' in tf_var.name and 'kernel' in tf_var.name:
elif "rnn" in tf_var.name and "kernel" in tf_var.name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
elif 'rnn' in tf_var.name and 'bias' in tf_var.name:
elif "rnn" in tf_var.name and "bias" in tf_var.name:
bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
assert len(bias_vectors) == 2
numpy_weight = bias_vectors[0] + bias_vectors[1]
elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name:
elif "linear_layer" in torch_layer_name and "weight" in torch_var_name:
numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
else:
numpy_weight = torch_weight.detach().cpu().numpy()
assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
assert np.all(
tf_var.shape == numpy_weight.shape
), f" [!] weight shapes does not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
tf.keras.backend.set_value(tf_var, numpy_weight)
return tf_vars
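For reference, convert_tf_name above is a pure string substitution, so it can be sanity-checked in isolation with the function as defined here. The layer names below are illustrative only; real names come from model.weights.
for name in (
    "encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0",
    "decoder/linear_projection/linear_layer/kernel:0",
    "postnet/convolutions_0/batch_normalization/gamma:0",
):
    print(name, "->", convert_tf_name(name))
# ':0' is dropped, '/' becomes '.', kernel/gamma map to 'weight', beta maps to 'bias', etc.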

View File

@ -7,20 +7,20 @@ import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
'model': model.weights,
'optimizer': optimizer,
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, 'wb'))
pickle.dump(state, open(output_path, "wb"))
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, 'rb'))
chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
checkpoint = pickle.load(open(checkpoint_path, "rb"))
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
@ -32,8 +32,8 @@ def load_checkpoint(model, checkpoint_path):
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if 'r' in checkpoint.keys():
model.decoder.set_r(checkpoint['r'])
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
@ -45,8 +45,7 @@ def sequence_mask(sequence_length, max_len=None):
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_length_expand = (
sequence_length.unsqueeze(1).expand_as(seq_range_expand))
seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
# B x T_max
return seq_range_expand < seq_length_expand
@ -62,42 +61,42 @@ def count_parameters(model, c):
try:
return model.count_params()
except RuntimeError:
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
input_lengths = np.random.randint(100, 129, (8, ))
input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32"))
input_lengths = np.random.randint(100, 129, (8,))
input_lengths[-1] = 128
input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
mel_spec = np.random.rand(8, 2 * c.r,
c.audio['num_mels']).astype('float32')
input_lengths = tf.convert_to_tensor(input_lengths.astype("int32"))
mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32")
mel_spec = tf.convert_to_tensor(mel_spec)
speaker_ids = np.random.randint(
0, 5, (8, )) if c.use_speaker_embedding else None
speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None
_ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
return model.count_params()
def setup_model(num_chars, num_speakers, c, enable_tflite=False):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.tts.tf.models.' + c.model.lower())
MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
MyModel = getattr(MyModel, c.model)
if c.model.lower() in "tacotron":
raise NotImplementedError(' [!] Tacotron model is not ready.')
raise NotImplementedError(" [!] Tacotron model is not ready.")
# tacotron2
model = MyModel(num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
enable_tflite=enable_tflite)
model = MyModel(
num_chars=num_chars,
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
enable_tflite=enable_tflite,
)
return model
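A hedged sketch of how this factory is typically driven. Here c stands for an already-loaded training config carrying the audio/attention fields referenced above, and the character count is a placeholder.
num_chars = 60        # placeholder; normally len(symbols) or len(phonemes)
num_speakers = 1
model = setup_model(num_chars, num_speakers, c, enable_tflite=False)
n_params = count_parameters(model, c)  # falls back to a dummy forward pass if count_params() fails
print(f" > Model has {n_params} parameters")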

View File

@ -5,20 +5,20 @@ import tensorflow as tf
def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
state = {
'model': model.weights,
'optimizer': optimizer,
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r
"model": model.weights,
"optimizer": optimizer,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
}
state.update(kwargs)
pickle.dump(state, open(output_path, 'wb'))
pickle.dump(state, open(output_path, "wb"))
def load_checkpoint(model, checkpoint_path):
checkpoint = pickle.load(open(checkpoint_path, 'rb'))
chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
checkpoint = pickle.load(open(checkpoint_path, "rb"))
chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
tf_vars = model.weights
for tf_var in tf_vars:
layer_name = tf_var.name
@ -30,8 +30,8 @@ def load_checkpoint(model, checkpoint_path):
chkp_var_value = chkp_var_dict[layer_name]
tf.keras.backend.set_value(tf_var, chkp_var_value)
if 'r' in checkpoint.keys():
model.decoder.set_r(checkpoint['r'])
if "r" in checkpoint.keys():
model.decoder.set_r(checkpoint["r"])
return model
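Since the checkpoint is simply a pickled dict of weight arrays plus metadata, a save/restore round trip looks like the sketch below; the path and counters are placeholders and model is assumed to come from a training run.
ckpt_path = "/tmp/checkpoint_10000.pkl"  # placeholder path
save_checkpoint(model, optimizer=None, current_step=10000, epoch=5, r=2, output_path=ckpt_path)

# later, after rebuilding the same architecture:
model = load_checkpoint(model, ckpt_path)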

View File

@ -1,25 +1,20 @@
import tensorflow as tf
def convert_tacotron2_to_tflite(model,
output_path=None,
experimental_converter=True):
def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True):
"""Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
provided, else return TFLite model."""
concrete_function = model.inference_tflite.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[concrete_function])
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
converter.experimental_new_converter = experimental_converter
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
]
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
if output_path is not None:
# save the model binary if output_path is provided
with open(output_path, 'wb') as f:
with open(output_path, "wb") as f:
f.write(tflite_model)
return None
return tflite_model
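Putting it together, a hedged export sketch; the path is a placeholder and model is assumed to be a TF Tacotron2 created with enable_tflite=True (so the fixed-size TensorArray decoding path shown earlier is used) and with weights already loaded.
model.build_inference()  # trace once so the concrete function exists
convert_tacotron2_to_tflite(model, output_path="/tmp/tacotron2.tflite")
# with output_path given, the binary is written to disk and the function returns None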

View File

@ -1,4 +1,3 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
@ -31,38 +30,37 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
# check num first
nd = str(num)
if abs(float(nd)) >= 1e48:
raise ValueError('number out of range')
if 'e' in nd:
raise ValueError('scientific notation is not supported')
c_symbol = '正负点' if simp else '正負點'
raise ValueError("number out of range")
if "e" in nd:
raise ValueError("scientific notation is not supported")
c_symbol = "正负点" if simp else "正負點"
if o: # formal
twoalt = False
if big:
c_basic = '零壹贰叁肆伍陆柒捌玖' if simp else '零壹貳參肆伍陸柒捌玖'
c_unit1 = '拾佰仟'
c_twoalt = '贰' if simp else '貳'
c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖"
c_unit1 = "拾佰仟"
c_twoalt = "" if simp else ""
else:
c_basic = '〇一二三四五六七八九' if o else '零一二三四五六七八九'
c_unit1 = '十百千'
c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九"
c_unit1 = "十百千"
if twoalt:
c_twoalt = '两' if simp else '兩'
c_twoalt = "" if simp else ""
else:
c_twoalt = '二'
c_unit2 = '万亿兆京垓秭穰沟涧正载' if simp else '萬億兆京垓秭穰溝澗正載'
revuniq = lambda l: ''.join(k for k, g in itertools.groupby(reversed(l)))
c_twoalt = ""
c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載"
revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l)))
nd = str(num)
result = []
if nd[0] == '+':
if nd[0] == "+":
result.append(c_symbol[0])
elif nd[0] == '-':
elif nd[0] == "-":
result.append(c_symbol[1])
if '.' in nd:
integer, remainder = nd.lstrip('+-').split('.')
if "." in nd:
integer, remainder = nd.lstrip("+-").split(".")
else:
integer, remainder = nd.lstrip('+-'), None
integer, remainder = nd.lstrip("+-"), None
if int(integer):
splitted = [integer[max(i - 4, 0):i]
for i in range(len(integer), 0, -4)]
splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)]
intresult = []
for nu, unit in enumerate(splitted):
# special cases
@ -75,17 +73,17 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
ulist = []
unit = unit.zfill(4)
for nc, ch in enumerate(reversed(unit)):
if ch == '0':
if ch == "0":
if ulist: # ???0
ulist.append(c_basic[0])
elif nc == 0:
ulist.append(c_basic[int(ch)])
elif nc == 1 and ch == '1' and unit[1] == '0':
elif nc == 1 and ch == "1" and unit[1] == "0":
# special case for tens
# edit the 'elif' if you don't like
# 十四, 三千零十四, 三千三百一十四
ulist.append(c_unit1[0])
elif nc > 1 and ch == '2':
elif nc > 1 and ch == "2":
ulist.append(c_twoalt + c_unit1[nc - 1])
else:
ulist.append(c_basic[int(ch)] + c_unit1[nc - 1])
@ -99,10 +97,8 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
result.append(c_basic[0])
if remainder:
result.append(c_symbol[2])
result.append(''.join(c_basic[int(ch)] for ch in remainder))
return ''.join(result)
result.append("".join(c_basic[int(ch)] for ch in remainder))
return "".join(result)
def _number_replace(match) -> str:
@ -127,5 +123,5 @@ def replace_numbers_to_characters_in_text(text: str) -> str:
Returns:
str: output text
"""
text = re.sub(r'[0-9]+', _number_replace, text)
text = re.sub(r"[0-9]+", _number_replace, text)
return text
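A quick usage sketch, assuming _number_replace applies the default simplified, non-formal settings of _num2chinese above; the expected readings are indicative (the exact zero handling follows the rules elided in the hunk).
print(replace_numbers_to_characters_in_text("我有2个苹果"))  # expected: 我有二个苹果
print(replace_numbers_to_characters_in_text("等了10分钟"))   # expected: 等了十分钟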

View File

@ -9,9 +9,7 @@ import jieba
def _chinese_character_to_pinyin(text: str) -> List[str]:
pinyins = pypinyin.pinyin(
text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True
)
pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
pinyins_flat_list = [item for sublist in pinyins for item in sublist]
return pinyins_flat_list
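With Style.TONE3 and neutral_tone_with_five=True, the helper returns one tone-numbered syllable per character; a small sketch (the output shown is the expected pypinyin reading).
print(_chinese_character_to_pinyin("你好"))  # expected: ['ni3', 'hao3']
# neutral-tone syllables come back with a trailing '5' because neutral_tone_with_five=True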

View File

@ -1,4 +1,3 @@
PINYIN_DICT = {
"a": ["a"],
"ai": ["ai"],

View File

@ -4,8 +4,7 @@ import numpy as np
def _pad_data(x, length):
_pad = 0
assert x.ndim == 1
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)
def prepare_data(inputs):
@ -14,12 +13,9 @@ def prepare_data(inputs):
def _pad_tensor(x, length):
_pad = 0.
_pad = 0.0
assert x.ndim == 2
x = np.pad(
x, [[0, 0], [0, length - x.shape[1]]],
mode='constant',
constant_values=_pad)
x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad)
return x
@ -31,10 +27,9 @@ def prepare_tensor(inputs, out_steps):
def _pad_stop_target(x, length):
_pad = 0.
_pad = 0.0
assert x.ndim == 1
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)
def prepare_stop_target(inputs, out_steps):
@ -46,22 +41,18 @@ def prepare_stop_target(inputs, out_steps):
def pad_per_step(inputs, pad_len):
return np.pad(
inputs, [[0, 0], [0, 0], [0, pad_len]],
mode='constant',
constant_values=0.0)
return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)
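The padding helpers above are easiest to read from their output shapes; a toy sketch using the functions as defined here (array sizes are illustrative).
import numpy as np

seq = np.ones(5)
print(_pad_data(seq, 8).shape)       # (8,)       zero-pad a 1D id sequence

mel = np.ones((80, 37))              # (num_mels, T)
print(_pad_tensor(mel, 40).shape)    # (80, 40)   pad the time axis

batch = np.ones((2, 80, 40))         # (B, num_mels, T)
print(pad_per_step(batch, 3).shape)  # (2, 80, 43) append pad_len zero frames on the time axis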
# pylint: disable=attribute-defined-outside-init
class StandardScaler():
class StandardScaler:
def set_stats(self, mean, scale):
self.mean_ = mean
self.scale_ = scale
def reset_stats(self):
delattr(self, 'mean_')
delattr(self, 'scale_')
delattr(self, "mean_")
delattr(self, "scale_")
def transform(self, X):
X = np.asarray(X)

View File

@ -28,277 +28,334 @@ def split_dataset(items):
return items_eval, items
return items[:eval_split_size], items[eval_split_size:]
# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
if max_len is None:
max_len = sequence_length.data.max()
seq_range = torch.arange(max_len,
dtype=sequence_length.dtype,
device=sequence_length.device)
seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
# B x T_max
return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
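A quick check of sequence_mask on a toy batch of lengths (this mask is what the loss_masking option checked further down relies on):
import torch

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])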
def to_camel(text):
text = text.capitalize()
text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
text = text.replace('Tts', 'TTS')
text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
text = text.replace("Tts", "TTS")
return text
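to_camel maps the lowercase model names accepted in the config onto their class names, for example:
print(to_camel("tacotron2"))      # Tacotron2
print(to_camel("glow_tts"))       # GlowTTS
print(to_camel("speedy_speech"))  # SpeedySpeech
print(to_camel("align_tts"))      # AlignTTS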
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
print(" > Using model: {}".format(c.model))
MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
MyModel = importlib.import_module("TTS.tts.models." + c.model.lower())
MyModel = getattr(MyModel, to_camel(c.model))
if c.model.lower() in "tacotron":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'],
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
decoder_output_dim=c.audio["num_mels"],
gst=c.use_gst,
gst_embedding_dim=c.gst["gst_embedding_dim"],
gst_num_heads=c.gst["gst_num_heads"],
gst_style_tokens=c.gst["gst_style_tokens"],
gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
memory_size=c.memory_size,
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim,
)
elif c.model.lower() == "tacotron2":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
gst=c.use_gst,
gst_embedding_dim=c.gst['gst_embedding_dim'],
gst_num_heads=c.gst['gst_num_heads'],
gst_style_tokens=c.gst['gst_style_tokens'],
gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim)
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
num_speakers=num_speakers,
r=c.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
gst=c.use_gst,
gst_embedding_dim=c.gst["gst_embedding_dim"],
gst_num_heads=c.gst["gst_num_heads"],
gst_style_tokens=c.gst["gst_style_tokens"],
gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
double_decoder_consistency=c.double_decoder_consistency,
ddc_r=c.ddc_r,
speaker_embedding_dim=speaker_embedding_dim,
)
elif c.model.lower() == "glow_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels_enc=c['hidden_channels_encoder'],
hidden_channels_dec=c['hidden_channels_decoder'],
hidden_channels_dp=c['hidden_channels_duration_predictor'],
out_channels=c.audio['num_mels'],
encoder_type=c.encoder_type,
encoder_params=c.encoder_params,
use_encoder_prenet=c["use_encoder_prenet"],
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=1,
num_block_layers=4,
dropout_p_dec=0.05,
num_speakers=num_speakers,
c_in_channels=0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
mean_only=True,
external_speaker_embedding_dim=speaker_embedding_dim)
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
hidden_channels_enc=c["hidden_channels_encoder"],
hidden_channels_dec=c["hidden_channels_decoder"],
hidden_channels_dp=c["hidden_channels_duration_predictor"],
out_channels=c.audio["num_mels"],
encoder_type=c.encoder_type,
encoder_params=c.encoder_params,
use_encoder_prenet=c["use_encoder_prenet"],
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=1,
num_block_layers=4,
dropout_p_dec=0.05,
num_speakers=num_speakers,
c_in_channels=0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
mean_only=True,
external_speaker_embedding_dim=speaker_embedding_dim,
)
elif c.model.lower() == "speedy_speech":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'],
positional_encoding=c['positional_encoding'],
encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio["num_mels"],
hidden_channels=c["hidden_channels"],
positional_encoding=c["positional_encoding"],
encoder_type=c["encoder_type"],
encoder_params=c["encoder_params"],
decoder_type=c["decoder_type"],
decoder_params=c["decoder_params"],
c_in_channels=0,
)
elif c.model.lower() == "align_tts":
model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio['num_mels'],
hidden_channels=c['hidden_channels'],
hidden_channels_dp=c['hidden_channels_dp'],
encoder_type=c['encoder_type'],
encoder_params=c['encoder_params'],
decoder_type=c['decoder_type'],
decoder_params=c['decoder_params'],
c_in_channels=0)
model = MyModel(
num_chars=num_chars + getattr(c, "add_blank", False),
out_channels=c.audio["num_mels"],
hidden_channels=c["hidden_channels"],
hidden_channels_dp=c["hidden_channels_dp"],
encoder_type=c["encoder_type"],
encoder_params=c["encoder_params"],
decoder_type=c["decoder_type"],
decoder_params=c["decoder_params"],
c_in_channels=0,
)
return model
def is_tacotron(c):
return 'tacotron' in c['model'].lower()
return "tacotron" in c["model"].lower()
def check_config_tts(c):
check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
check_argument('run_name', c, restricted=True, val_type=str)
check_argument('run_description', c, val_type=str)
check_argument(
"model",
c,
enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
restricted=True,
val_type=str,
)
check_argument("run_name", c, restricted=True, val_type=str)
check_argument("run_description", c, val_type=str)
# AUDIO
check_argument('audio', c, restricted=True, val_type=dict)
check_argument("audio", c, restricted=True, val_type=dict)
# audio processing parameters
check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
check_argument(
"frame_length_ms",
c["audio"],
restricted=True,
val_type=float,
min_val=10,
max_val=1000,
alternative="win_length",
)
check_argument(
"frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
)
check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)
# vocabulary parameters
check_argument('characters', c, restricted=False, val_type=dict)
check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
check_argument("characters", c, restricted=False, val_type=dict)
check_argument(
"pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
)
check_argument(
"characters",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys(),
val_type=str,
)
check_argument(
"phonemes",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys() and c["use_phonemes"],
val_type=str,
)
check_argument(
"punctuations",
c["characters"] if "characters" in c.keys() else {},
restricted="characters" in c.keys(),
val_type=str,
)
# normalization parameters
check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
check_argument('trim_db', c['audio'], restricted=True, val_type=int)
check_argument("signal_norm", c["audio"], restricted=True, val_type=bool)
check_argument("symmetric_norm", c["audio"], restricted=True, val_type=bool)
check_argument("max_norm", c["audio"], restricted=True, val_type=float, min_val=0.1, max_val=1000)
check_argument("clip_norm", c["audio"], restricted=True, val_type=bool)
check_argument("mel_fmin", c["audio"], restricted=True, val_type=float, min_val=0.0, max_val=1000)
check_argument("mel_fmax", c["audio"], restricted=True, val_type=float, min_val=500.0)
check_argument("spec_gain", c["audio"], restricted=True, val_type=[int, float], min_val=1, max_val=100)
check_argument("do_trim_silence", c["audio"], restricted=True, val_type=bool)
check_argument("trim_db", c["audio"], restricted=True, val_type=int)
# training parameters
check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
check_argument('r', c, restricted=True, val_type=int, min_val=1)
check_argument('gradual_training', c, restricted=False, val_type=list)
check_argument('mixed_precision', c, restricted=False, val_type=bool)
check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
check_argument("r", c, restricted=True, val_type=int, min_val=1)
check_argument("gradual_training", c, restricted=False, val_type=list)
check_argument("mixed_precision", c, restricted=False, val_type=bool)
# check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)
# loss parameters
check_argument('loss_masking', c, restricted=True, val_type=bool)
if c['model'].lower() in ['tacotron', 'tacotron2']:
check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
check_argument("loss_masking", c, restricted=True, val_type=bool)
if c["model"].lower() in ["tacotron", "tacotron2"]:
check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
if c["model"].lower in ["speedy_speech", "align_tts"]:
check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0)
check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0)
# validation parameters
check_argument('run_eval', c, restricted=True, val_type=bool)
check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
check_argument('test_sentences_file', c, restricted=False, val_type=str)
check_argument("run_eval", c, restricted=True, val_type=bool)
check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
check_argument("test_sentences_file", c, restricted=False, val_type=str)
# optimizer
check_argument('noam_schedule', c, restricted=False, val_type=bool)
check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
check_argument('lr', c, restricted=True, val_type=float, min_val=0)
check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
check_argument("noam_schedule", c, restricted=False, val_type=bool)
check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
check_argument("lr", c, restricted=True, val_type=float, min_val=0)
check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)
# tacotron prenet
check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
check_argument("memory_size", c, restricted=is_tacotron(c), val_type=int, min_val=-1)
check_argument("prenet_type", c, restricted=is_tacotron(c), val_type=str, enum_list=["original", "bn"])
check_argument("prenet_dropout", c, restricted=is_tacotron(c), val_type=bool)
# attention
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
check_argument(
"attention_type",
c,
restricted=is_tacotron(c),
val_type=str,
enum_list=["graves", "original", "dynamic_convolution"],
)
check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)
if c['model'].lower() in ['tacotron', 'tacotron2']:
if c["model"].lower() in ["tacotron", "tacotron2"]:
# stopnet
check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
check_argument("stopnet", c, restricted=is_tacotron(c), val_type=bool)
check_argument("separate_stopnet", c, restricted=is_tacotron(c), val_type=bool)
# Model Parameters for non-tacotron models
if c['model'].lower in ["speedy_speech", "align_tts"]:
check_argument('positional_encoding', c, restricted=True, val_type=type)
check_argument('encoder_type', c, restricted=True, val_type=str)
check_argument('encoder_params', c, restricted=True, val_type=dict)
check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
if c["model"].lower in ["speedy_speech", "align_tts"]:
check_argument("positional_encoding", c, restricted=True, val_type=type)
check_argument("encoder_type", c, restricted=True, val_type=str)
check_argument("encoder_params", c, restricted=True, val_type=dict)
check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict)
# GlowTTS parameters
check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)
# tensorboard
check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
check_argument('checkpoint', c, restricted=True, val_type=bool)
check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
check_argument("checkpoint", c, restricted=True, val_type=bool)
check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
# dataloading
# pylint: disable=import-outside-toplevel
from TTS.tts.utils.text import cleaners
check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)
# paths
check_argument('output_path', c, restricted=True, val_type=str)
check_argument("output_path", c, restricted=True, val_type=str)
# multi-speaker and gst
check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
check_argument("use_speaker_embedding", c, restricted=True, val_type=bool)
check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool)
check_argument(
"external_speaker_embedding_file", c, restricted=c["use_external_speaker_embedding_file"], val_type=str
)
if c["model"].lower() in ["tacotron", "tacotron2"] and c["use_gst"]:
check_argument("use_gst", c, restricted=is_tacotron(c), val_type=bool)
check_argument("gst", c, restricted=is_tacotron(c), val_type=dict)
check_argument("gst_style_input", c["gst"], restricted=is_tacotron(c), val_type=[str, dict])
check_argument("gst_embedding_dim", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
check_argument("gst_use_speaker_embedding", c["gst"], restricted=is_tacotron(c), val_type=bool)
check_argument("gst_num_heads", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
check_argument("gst_style_tokens", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
# datasets - checking only the first entry
check_argument('datasets', c, restricted=True, val_type=list)
for dataset_entry in c['datasets']:
check_argument('name', dataset_entry, restricted=True, val_type=str)
check_argument('path', dataset_entry, restricted=True, val_type=str)
check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
check_argument("datasets", c, restricted=True, val_type=list)
for dataset_entry in c["datasets"]:
check_argument("name", dataset_entry, restricted=True, val_type=str)
check_argument("path", dataset_entry, restricted=True, val_type=str)
check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)

View File

@ -6,7 +6,6 @@ import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler
def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False): # pylint: disable=redefined-builtin
"""Load ```TTS.tts.models``` checkpoints.
@ -20,33 +19,25 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
[type]: [description]
"""
try:
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
except ModuleNotFoundError:
pickle_tts.Unpickler = RenamingUnpickler
state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
model.load_state_dict(state['model'])
if amp and 'amp' in state:
amp.load_state_dict(state['amp'])
state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
model.load_state_dict(state["model"])
if amp and "amp" in state:
amp.load_state_dict(state["amp"])
if use_cuda:
model.cuda()
# set model stepsize
if hasattr(model.decoder, 'r'):
model.decoder.set_r(state['r'])
print(" > Model r: ", state['r'])
if hasattr(model.decoder, "r"):
model.decoder.set_r(state["r"])
print(" > Model r: ", state["r"])
if eval:
model.eval()
return model, state
def save_model(model,
optimizer,
current_step,
epoch,
r,
output_path,
characters,
amp_state_dict=None,
**kwargs):
def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs):
"""Save ```TTS.tts.models``` states with extra fields.
Args:
@ -59,27 +50,26 @@ def save_model(model,
characters (list): list of characters used in the model.
amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
"""
if hasattr(model, 'module'):
if hasattr(model, "module"):
model_state = model.module.state_dict()
else:
model_state = model.state_dict()
state = {
'model': model_state,
'optimizer': optimizer.state_dict() if optimizer is not None else None,
'step': current_step,
'epoch': epoch,
'date': datetime.date.today().strftime("%B %d, %Y"),
'r': r,
'characters': characters
"model": model_state,
"optimizer": optimizer.state_dict() if optimizer is not None else None,
"step": current_step,
"epoch": epoch,
"date": datetime.date.today().strftime("%B %d, %Y"),
"r": r,
"characters": characters,
}
if amp_state_dict:
state['amp'] = amp_state_dict
state["amp"] = amp_state_dict
state.update(kwargs)
torch.save(state, output_path)
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
characters, **kwargs):
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs):
"""Save model checkpoint, intended for saving checkpoints at training.
Args:
@ -91,14 +81,15 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
output_path (str): output path to save the model file.
characters (list): list of characters used in the model.
"""
file_name = 'checkpoint_{}.pth.tar'.format(current_step)
file_name = "checkpoint_{}.pth.tar".format(current_step)
checkpoint_path = os.path.join(output_folder, file_name)
print(" > CHECKPOINT : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)
def save_best_model(target_loss, best_loss, model, optimizer, current_step,
epoch, r, output_folder, characters, **kwargs):
def save_best_model(
target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs
):
"""Save model checkpoint, intended for saving the best model after each epoch.
It compares the current model loss with the best loss so far and saves the
model if the current loss is better.
@ -118,9 +109,11 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step,
float: updated current best loss.
"""
if target_loss < best_loss:
file_name = 'best_model.pth.tar'
file_name = "best_model.pth.tar"
checkpoint_path = os.path.join(output_folder, file_name)
print(" >> BEST MODEL : {}".format(checkpoint_path))
save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
save_model(
model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs
)
best_loss = target_loss
return best_loss
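
The best-model logic above reduces to a simple comparison: save only when the new validation loss improves on the best seen so far, then update the running best. A minimal sketch of that control flow; save_fn is a stand-in for save_model and not part of the repo.

def save_best_sketch(target_loss, best_loss, save_fn, output_path):
    # Save only if the current loss beats the best loss recorded so far.
    if target_loss < best_loss:
        save_fn(output_path)          # in the real code: save_model(..., model_loss=target_loss)
        best_loss = target_loss       # the improved loss becomes the new reference
    return best_loss

best = float("inf")
for epoch_loss in [0.9, 0.7, 0.8, 0.6]:
    best = save_best_sketch(epoch_loss, best, lambda p: print("saved", p), "best_model.pth.tar")
# prints "saved" for 0.9, 0.7 and 0.6, but not for 0.8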

View File

@ -10,7 +10,7 @@ def make_speakers_json_path(out_path):
def load_speaker_mapping(out_path):
"""Loads speaker mapping if already present."""
try:
if os.path.splitext(out_path)[1] == '.json':
if os.path.splitext(out_path)[1] == ".json":
json_file = out_path
else:
json_file = make_speakers_json_path(out_path)
@ -19,6 +19,7 @@ def load_speaker_mapping(out_path):
except FileNotFoundError:
return {}
def save_speaker_mapping(out_path, speaker_mapping):
"""Saves speaker mapping if not yet present."""
speakers_json_path = make_speakers_json_path(out_path)
@ -31,40 +32,49 @@ def get_speakers(items):
speakers = {e[2] for e in items}
return sorted(speakers)
def parse_speakers(c, args, meta_data_train, OUT_PATH):
""" Returns number of speakers, speaker embedding shape and speaker mapping"""
if c.use_speaker_embedding:
speakers = get_speakers(meta_data_train)
if args.restore_path:
if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
if not speaker_mapping:
print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
print(
"WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
)
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
if not speaker_mapping:
raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
raise RuntimeError(
"You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
elif (
not c.use_external_speaker_embedding_file
): # if restore checkpoint and don't use External Embedding file
prev_out_path = os.path.dirname(args.restore_path)
speaker_mapping = load_speaker_mapping(prev_out_path)
speaker_embedding_dim = None
assert all([speaker in speaker_mapping
for speaker in speakers]), "As of now you, you cannot " \
"introduce new speakers to " \
"a previously trained model."
elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
assert all(speaker in speaker_mapping for speaker in speakers), (
"As of now you, you cannot " "introduce new speakers to " "a previously trained model."
)
elif (
c.use_external_speaker_embedding_file and c.external_speaker_embedding_file
): # if start new train using External Embedding file
speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
elif (
c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
): # if start new train using External Embedding file and don't pass external embedding file
raise "use_external_speaker_embedding_file is True, so you need pass a external speaker embedding file, run GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in notebooks/ folder"
else: # if start new train and don't use External Embedding file
else: # if start new train and don't use External Embedding file
speaker_mapping = {name: i for i, name in enumerate(speakers)}
speaker_embedding_dim = None
save_speaker_mapping(OUT_PATH, speaker_mapping)
num_speakers = len(speaker_mapping)
print(" > Training with {} speakers: {}".format(
len(speakers), ", ".join(speakers)))
print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
else:
num_speakers = 0
speaker_embedding_dim = None
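
For reference, the speaker mapping handled above is a plain JSON dict; with external speaker embeddings each entry carries an "embedding" vector, and its length is what becomes speaker_embedding_dim. A small self-contained sketch of that layout (the speaker names and values below are made up):

import json

# hypothetical speakers.json content when external embeddings are used
speaker_mapping = {
    "p225": {"embedding": [0.1, -0.2, 0.3, 0.0]},
    "p226": {"embedding": [0.0, 0.5, -0.1, 0.2]},
}
# same expression as in parse_speakers(): dimension taken from the first entry
speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
print(speaker_embedding_dim)  # 4

# without external embeddings the mapping is just name -> integer id
simple_mapping = {name: i for i, name in enumerate(sorted(speaker_mapping))}
print(json.dumps(simple_mapping))  # {"p225": 0, "p226": 1}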

View File

@ -8,8 +8,9 @@ from torch.autograd import Variable
def gaussian(window_size, sigma):
gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
return gauss/gauss.sum()
gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
return gauss / gauss.sum()
def create_window(window_size, channel):
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
@ -24,25 +25,22 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1*mu2
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(
img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(
img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(
img1 * img2, window, padding=window_size // 2,
groups=channel) - mu1_mu2
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
C1 = 0.01**2
C2 = 0.03**2
C1 = 0.01 ** 2
C2 = 0.03 ** 2
ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
if size_average:
return ssim_map.mean()
return ssim_map.mean(1).mean(1).mean(1)
class SSIM(torch.nn.Module):
def __init__(self, window_size=11, size_average=True):
super().__init__()
@ -66,7 +64,6 @@ class SSIM(torch.nn.Module):
self.window = window
self.channel = channel
return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
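
The SSIM module above builds its convolution window from gaussian(); a quick self-contained check, using the same formula as in the diff, that the window is normalized and peaks at its center:

from math import exp
import torch

def gaussian(window_size, sigma):
    # identical to the helper above: discrete Gaussian, normalized to sum to 1
    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()

win = gaussian(11, 1.5)
print(win.sum())      # tensor(1.) up to float precision
print(win.argmax())   # tensor(5) -> peak at the window center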

View File

@ -1,8 +1,10 @@
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import pkg_resources
installed = {pkg.key for pkg in pkg_resources.working_set} #pylint: disable=not-an-iterable
if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable
if "tensorflow" in installed or "tensorflow-gpu" in installed:
import tensorflow as tf
import torch
import numpy as np
@ -14,19 +16,26 @@ def text_to_seqvec(text, CONFIG):
# text ot phonemes to sequence vector
if CONFIG.use_phonemes:
seq = np.asarray(
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
dtype=np.int32)
phoneme_to_sequence(
text,
text_cleaner,
CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
),
dtype=np.int32,
)
else:
seq = np.asarray(text_to_sequence(
text,
text_cleaner,
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
add_blank=CONFIG['add_blank']
if 'add_blank' in CONFIG.keys() else False),
dtype=np.int32)
seq = np.asarray(
text_to_sequence(
text,
text_cleaner,
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
),
dtype=np.int32,
)
return seq
@ -47,86 +56,95 @@ def numpy_to_tf(np_array, dtype):
def compute_style_mel(style_wav, ap, cuda=False):
style_mel = torch.FloatTensor(ap.melspectrogram(
ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
if cuda:
return style_mel.cuda()
return style_mel
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
if 'tacotron' in CONFIG.model.lower():
if "tacotron" in CONFIG.model.lower():
if CONFIG.use_gst:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
else:
if truncated:
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
else:
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
elif 'glow' in CONFIG.model.lower():
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
)
elif "glow" in CONFIG.model.lower():
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
if hasattr(model, "module"):
# distributed model
postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, _, _, _, alignments, _, _ = model.module.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
)
else:
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, _, _, _, alignments, _, _ = model.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
stop_tokens = None
elif CONFIG.model.lower() in ['speedy_speech', 'align_tts']:
elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
if hasattr(model, 'module'):
if hasattr(model, "module"):
# distributed model
postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, alignments = model.module.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
)
else:
postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
postnet_output, alignments = model.inference(
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
)
postnet_output = postnet_output.permute(0, 2, 1)
# these only belong to tacotron models.
decoder_output = None
stop_tokens = None
else:
raise ValueError('[!] Unknown model name.')
raise ValueError("[!] Unknown model name.")
return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
if CONFIG.use_gst and style_mel is not None:
raise NotImplementedError(' [!] GST inference not implemented for TF')
raise NotImplementedError(" [!] GST inference not implemented for TF")
if truncated:
raise NotImplementedError(' [!] Truncated inference not implemented for TF')
raise NotImplementedError(" [!] Truncated inference not implemented for TF")
if speaker_id is not None:
raise NotImplementedError(' [!] Multi-Speaker not implemented for TF')
raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
# TODO: handle multispeaker case
decoder_output, postnet_output, alignments, stop_tokens = model(
inputs, training=False)
decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
return decoder_output, postnet_output, alignments, stop_tokens
def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
if CONFIG.use_gst and style_mel is not None:
raise NotImplementedError(' [!] GST inference not implemented for TfLite')
raise NotImplementedError(" [!] GST inference not implemented for TfLite")
if truncated:
raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
if speaker_id is not None:
raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
# get input and output details
input_details = model.get_input_details()
output_details = model.get_output_details()
# reshape input tensor for the new input shape
model.resize_tensor_input(input_details[0]['index'], inputs.shape)
model.resize_tensor_input(input_details[0]["index"], inputs.shape)
model.allocate_tensors()
detail = input_details[0]
# input_shape = detail['shape']
model.set_tensor(detail['index'], inputs)
model.set_tensor(detail["index"], inputs)
# run the model
model.invoke()
# collect outputs
decoder_output = model.get_tensor(output_details[0]['index'])
postnet_output = model.get_tensor(output_details[1]['index'])
decoder_output = model.get_tensor(output_details[0]["index"])
postnet_output = model.get_tensor(output_details[1]["index"])
# tflite model only returns feature frames
return decoder_output, postnet_output, None, None
@ -154,7 +172,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):
def trim_silence(wav, ap):
return wav[:ap.find_endpoint(wav)]
return wav[: ap.find_endpoint(wav)]
def inv_spectrogram(postnet_output, ap, CONFIG):
@ -186,13 +204,13 @@ def embedding_to_torch(speaker_embedding, cuda=False):
# TODO: perform GL with pytorch for batching
def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
'''Apply griffin-lim to each sample iterating throught the first dimension.
"""Apply griffin-lim to each sample iterating throught the first dimension.
Args:
inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
input_lens (Tensor or np.Array): 1D array of sample lengths.
CONFIG (Dict): TTS config.
ap (AudioProcessor): TTS audio processor.
'''
"""
wavs = []
for idx, spec in enumerate(inputs):
wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length # inverse librosa padding
@ -202,39 +220,41 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
return wavs
def synthesis(model,
text,
CONFIG,
use_cuda,
ap,
speaker_id=None,
style_wav=None,
truncated=False,
enable_eos_bos_chars=False, #pylint: disable=unused-argument
use_griffin_lim=False,
do_trim_silence=False,
speaker_embedding=None,
backend='torch'):
def synthesis(
model,
text,
CONFIG,
use_cuda,
ap,
speaker_id=None,
style_wav=None,
truncated=False,
enable_eos_bos_chars=False, # pylint: disable=unused-argument
use_griffin_lim=False,
do_trim_silence=False,
speaker_embedding=None,
backend="torch",
):
"""Synthesize voice for the given text.
Args:
model (TTS.tts.models): model to synthesize.
text (str): target text
CONFIG (dict): config dictionary to be loaded from config.json.
use_cuda (bool): enable cuda.
ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process
model outputs.
speaker_id (int): id of speaker
style_wav (str | Dict[str, float]): Uses for style embedding of GST.
truncated (bool): keep model states after inference. It can be used
for continuous inference at long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
do_trim_silence (bool): trim silence after synthesis.
backend (str): tf or torch
Args:
model (TTS.tts.models): model to synthesize.
text (str): target text
CONFIG (dict): config dictionary to be loaded from config.json.
use_cuda (bool): enable cuda.
ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process
model outputs.
speaker_id (int): id of speaker
style_wav (str | Dict[str, float]): Used for style embedding of GST.
truncated (bool): keep model states after inference. It can be used
for continuous inference on long texts.
enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
do_trim_silence (bool): trim silence after synthesis.
backend (str): tf or torch
"""
# GST processing
style_mel = None
if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
if isinstance(style_wav, dict):
style_mel = style_wav
else:
@ -242,7 +262,7 @@ def synthesis(model,
# preprocess the given text
inputs = text_to_seqvec(text, CONFIG)
# pass tensors to backend
if backend == 'torch':
if backend == "torch":
if speaker_id is not None:
speaker_id = id_to_torch(speaker_id, cuda=use_cuda)
@ -253,31 +273,35 @@ def synthesis(model,
style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
inputs = inputs.unsqueeze(0)
elif backend == 'tf':
elif backend == "tf":
# TODO: handle speaker id for tf model
style_mel = numpy_to_tf(style_mel, tf.float32)
inputs = numpy_to_tf(inputs, tf.int32)
inputs = tf.expand_dims(inputs, 0)
elif backend == 'tflite':
elif backend == "tflite":
style_mel = numpy_to_tf(style_mel, tf.float32)
inputs = numpy_to_tf(inputs, tf.int32)
inputs = tf.expand_dims(inputs, 0)
# synthesize voice
if backend == 'torch':
if backend == "torch":
decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding)
model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding
)
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
postnet_output, decoder_output, alignments, stop_tokens)
elif backend == 'tf':
postnet_output, decoder_output, alignments, stop_tokens
)
elif backend == "tf":
decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
model, inputs, CONFIG, truncated, speaker_id, style_mel
)
postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
postnet_output, decoder_output, alignments, stop_tokens)
elif backend == 'tflite':
postnet_output, decoder_output, alignments, stop_tokens
)
elif backend == "tflite":
decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
model, inputs, CONFIG, truncated, speaker_id, style_mel)
postnet_output, decoder_output = parse_outputs_tflite(
postnet_output, decoder_output)
model, inputs, CONFIG, truncated, speaker_id, style_mel
)
postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
# convert outputs to numpy
# plot results
wav = None
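
Before run_model_torch is called, the integer sequence produced by text_to_seqvec is turned into a long tensor with a leading batch dimension (the inputs.unsqueeze(0) above). A self-contained sketch of that shape handling with a dummy sequence; the id values are made up, the real ones come from text_to_sequence/phoneme_to_sequence.

import numpy as np
import torch

seq = np.asarray([12, 5, 33, 7, 2], dtype=np.int32)   # dummy symbol ids
inputs = torch.from_numpy(seq.astype(np.int64))        # numpy_to_torch with torch.long
inputs = inputs.unsqueeze(0)                           # add batch dimension -> shape (1, T)
inputs_lengths = torch.tensor(inputs.shape[1:2])       # same trick used for the glow/speedy_speech branches
print(inputs.shape, inputs_lengths)                    # torch.Size([1, 5]) tensor([5])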

View File

@ -6,8 +6,7 @@ import phonemizer
from packaging import version
from phonemizer.phonemize import phonemize
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.symbols import (_bos, _eos, _punctuations,
make_symbols, phonemes, symbols)
from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols
from TTS.tts.utils.chinese_mandarin.phonemizer import chinese_text_to_phonemes
@ -22,14 +21,14 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")
# Regular expression matching punctuations, ignoring empty space
PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+'
PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"
def text2phone(text, language):
'''Convert graphemes to phonemes. For most of the languages, it calls
"""Convert graphemes to phonemes. For most of the languages, it calls
the phonemizer python library, which calls espeak/espeak-ng. For Chinese
Mandarin, it calls pypinyin plus a custom function for phonemizing.
Parameters:
@ -38,60 +37,73 @@ def text2phone(text, language):
Returns:
ph (str): phonemes as a string separated by "|"
ph = "ɪ|g|ˈ|z|æ|m|p|ə|l"
'''
"""
# TO REVIEW : How to have a good implementation for this?
if language == "zh-CN":
ph = chinese_text_to_phonemes(text)
return ph
seperator = phonemizer.separator.Separator(' |', '', '|')
#try:
seperator = phonemizer.separator.Separator(" |", "", "|")
# try:
punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
if version.parse(phonemizer.__version__) < version.parse('2.1'):
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
ph = ph[:-1].strip() # skip the last empty character
if version.parse(phonemizer.__version__) < version.parse("2.1"):
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend="espeak", language=language)
ph = ph[:-1].strip() # skip the last empty character
# phonemizer does not tackle punctuations. Here we do.
# Replace \n with matching punctuations.
if punctuations:
# if text ends with a punctuation.
if text[-1] == punctuations[-1]:
for punct in punctuations[:-1]:
ph = ph.replace('| |\n', '|'+punct+'| |', 1)
ph = ph.replace("| |\n", "|" + punct + "| |", 1)
ph = ph + punctuations[-1]
else:
for punct in punctuations:
ph = ph.replace('| |\n', '|'+punct+'| |', 1)
elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True, language_switch='remove-flags')
ph = ph.replace("| |\n", "|" + punct + "| |", 1)
elif version.parse(phonemizer.__version__) >= version.parse("2.1"):
ph = phonemize(
text,
separator=seperator,
strip=False,
njobs=1,
backend="espeak",
language=language,
preserve_punctuation=True,
language_switch="remove-flags",
)
# this is a simple fix for phonemizer.
# https://github.com/bootphon/phonemizer/issues/32
if punctuations:
for punctuation in punctuations:
ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |")
ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(
f"| |{punctuation}", f"|{punctuation}| |"
)
ph = ph[:-3]
else:
raise RuntimeError(" [!] Use 'phonemizer' version 2.1 or older.")
return ph
def intersperse(sequence, token):
result = [token] * (len(sequence) * 2 + 1)
result[1::2] = sequence
return result
def pad_with_eos_bos(phoneme_sequence, tp=None):
# pylint: disable=global-statement
global _phonemes_to_id, _bos, _eos
if tp:
_bos = tp['bos']
_eos = tp['eos']
_bos = tp["bos"]
_eos = tp["eos"]
_, _phonemes = make_symbols(**tp)
_phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}
return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]
def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False):
# pylint: disable=global-statement
global _phonemes_to_id, _phonemes
@ -105,23 +117,23 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=
if to_phonemes is None:
print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
# iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation.
for phoneme in filter(None, to_phonemes.split('|')):
for phoneme in filter(None, to_phonemes.split("|")):
sequence += _phoneme_to_sequence(phoneme)
# Append EOS char
if enable_eos_bos:
sequence = pad_with_eos_bos(sequence, tp=tp)
if add_blank:
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
return sequence
def sequence_to_phoneme(sequence, tp=None, add_blank=False):
# pylint: disable=global-statement
'''Converts a sequence of IDs back to a string'''
"""Converts a sequence of IDs back to a string"""
global _id_to_phonemes, _phonemes
if add_blank:
sequence = list(filter(lambda x: x != len(_phonemes), sequence))
result = ''
result = ""
if tp:
_, _phonemes = make_symbols(**tp)
_id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}
@ -130,22 +142,22 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
if symbol_id in _id_to_phonemes:
s = _id_to_phonemes[symbol_id]
result += s
return result.replace('}{', ' ')
return result.replace("}{", " ")
def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
'''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
"""Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
The text can optionally have ARPAbet sequences enclosed in curly braces embedded
in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Args:
text: string to convert to a sequence
cleaner_names: names of the cleaner functions to run the text through
Returns:
List of integers corresponding to the symbols in the text
'''
Returns:
List of integers corresponding to the symbols in the text
"""
# pylint: disable=global-statement
global _symbol_to_id, _symbols
if tp:
@ -159,18 +171,17 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
if not m:
sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
break
sequence += _symbols_to_sequence(
_clean_text(m.group(1), cleaner_names))
sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
sequence += _arpabet_to_sequence(m.group(2))
text = m.group(3)
if add_blank:
sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
return sequence
def sequence_to_text(sequence, tp=None, add_blank=False):
'''Converts a sequence of IDs back to a string'''
"""Converts a sequence of IDs back to a string"""
# pylint: disable=global-statement
global _id_to_symbol, _symbols
if add_blank:
@ -180,22 +191,22 @@ def sequence_to_text(sequence, tp=None, add_blank=False):
_symbols, _ = make_symbols(**tp)
_id_to_symbol = {i: s for i, s in enumerate(_symbols)}
result = ''
result = ""
for symbol_id in sequence:
if symbol_id in _id_to_symbol:
s = _id_to_symbol[symbol_id]
# Enclose ARPAbet back in curly braces:
if len(s) > 1 and s[0] == '@':
s = '{%s}' % s[1:]
if len(s) > 1 and s[0] == "@":
s = "{%s}" % s[1:]
result += s
return result.replace('}{', ' ')
return result.replace("}{", " ")
def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception('Unknown cleaner: %s' % name)
raise Exception("Unknown cleaner: %s" % name)
text = cleaner(text)
return text
@ -209,12 +220,12 @@ def _phoneme_to_sequence(phons):
def _arpabet_to_sequence(text):
return _symbols_to_sequence(['@' + s for s in text.split()])
return _symbols_to_sequence(["@" + s for s in text.split()])
def _should_keep_symbol(s):
return s in _symbol_to_id and s not in ['~', '^', '_']
return s in _symbol_to_id and s not in ["~", "^", "_"]
def _should_keep_phoneme(p):
return p in _phonemes_to_id and p not in ['~', '^', '_']
return p in _phonemes_to_id and p not in ["~", "^", "_"]
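
The add_blank option relies on intersperse(), defined above; it surrounds every id with a blank token whose id equals the vocabulary size. A tiny self-contained check of that behaviour:

def intersperse(sequence, token):
    # same as above: blank token before, between and after every symbol id
    result = [token] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result

print(intersperse([1, 2, 3], 0))   # [0, 1, 0, 2, 0, 3, 0]
# in text_to_sequence/phoneme_to_sequence the token is len(_symbols) / len(_phonemes)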

View File

@ -1,66 +1,73 @@
import re
# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
abbreviations_en = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("mrs", "misess"),
("mr", "mister"),
("dr", "doctor"),
("st", "saint"),
("co", "company"),
("jr", "junior"),
("maj", "major"),
("gen", "general"),
("drs", "doctors"),
("rev", "reverend"),
("lt", "lieutenant"),
("hon", "honorable"),
("sgt", "sergeant"),
("capt", "captain"),
("esq", "esquire"),
("ltd", "limited"),
("col", "colonel"),
("ft", "fort"),
]
]
# List of (regular expression, replacement) pairs for abbreviations in french:
abbreviations_fr = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [
('M', 'monsieur'),
('Mlle', 'mademoiselle'),
('Mlles', 'mesdemoiselles'),
('Mme', 'Madame'),
('Mmes', 'Mesdames'),
('N.B', 'nota bene'),
('M', 'monsieur'),
('p.c.q', 'parce que'),
('Pr', 'professeur'),
('qqch', 'quelque chose'),
('rdv', 'rendez-vous'),
('max', 'maximum'),
('min', 'minimum'),
('no', 'numéro'),
('adr', 'adresse'),
('dr', 'docteur'),
('st', 'saint'),
('co', 'companie'),
('jr', 'junior'),
('sgt', 'sergent'),
('capt', 'capitain'),
('col', 'colonel'),
('av', 'avenue'),
('av. J.-C', 'avant Jésus-Christ'),
('apr. J.-C', 'après Jésus-Christ'),
('art', 'article'),
('boul', 'boulevard'),
('c.-à-d', 'cest-à-dire'),
('etc', 'et cetera'),
('ex', 'exemple'),
('excl', 'exclusivement'),
('boul', 'boulevard'),
]] + [(re.compile('\\b%s' % x[0]), x[1]) for x in [
('Mlle', 'mademoiselle'),
('Mlles', 'mesdemoiselles'),
('Mme', 'Madame'),
('Mmes', 'Mesdames'),
]]
abbreviations_fr = [
(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
for x in [
("M", "monsieur"),
("Mlle", "mademoiselle"),
("Mlles", "mesdemoiselles"),
("Mme", "Madame"),
("Mmes", "Mesdames"),
("N.B", "nota bene"),
("M", "monsieur"),
("p.c.q", "parce que"),
("Pr", "professeur"),
("qqch", "quelque chose"),
("rdv", "rendez-vous"),
("max", "maximum"),
("min", "minimum"),
("no", "numéro"),
("adr", "adresse"),
("dr", "docteur"),
("st", "saint"),
("co", "companie"),
("jr", "junior"),
("sgt", "sergent"),
("capt", "capitain"),
("col", "colonel"),
("av", "avenue"),
("av. J.-C", "avant Jésus-Christ"),
("apr. J.-C", "après Jésus-Christ"),
("art", "article"),
("boul", "boulevard"),
("c.-à-d", "cest-à-dire"),
("etc", "et cetera"),
("ex", "exemple"),
("excl", "exclusivement"),
("boul", "boulevard"),
]
] + [
(re.compile("\\b%s" % x[0]), x[1])
for x in [
("Mlle", "mademoiselle"),
("Mlles", "mesdemoiselles"),
("Mme", "Madame"),
("Mmes", "Mesdames"),
]
]

View File

@ -1,4 +1,4 @@
'''
"""
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
@ -8,7 +8,7 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
"""
import re
from unidecode import unidecode
@ -19,13 +19,13 @@ from TTS.tts.utils.chinese_mandarin.numbers import replace_numbers_to_characters
# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
_whitespace_re = re.compile(r"\s+")
def expand_abbreviations(text, lang='en'):
if lang == 'en':
def expand_abbreviations(text, lang="en"):
if lang == "en":
_abbreviations = abbreviations_en
elif lang == 'fr':
elif lang == "fr":
_abbreviations = abbreviations_fr
for regex, replacement in _abbreviations:
text = re.sub(regex, replacement, text)
@ -41,7 +41,7 @@ def lowercase(text):
def collapse_whitespace(text):
return re.sub(_whitespace_re, ' ', text).strip()
return re.sub(_whitespace_re, " ", text).strip()
def convert_to_ascii(text):
@ -49,30 +49,32 @@ def convert_to_ascii(text):
def remove_aux_symbols(text):
text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
return text
def replace_symbols(text, lang='en'):
text = text.replace(';', ',')
text = text.replace('-', ' ')
text = text.replace(':', ',')
if lang == 'en':
text = text.replace('&', ' and ')
elif lang == 'fr':
text = text.replace('&', ' et ')
elif lang == 'pt':
text = text.replace('&', ' e ')
def replace_symbols(text, lang="en"):
text = text.replace(";", ",")
text = text.replace("-", " ")
text = text.replace(":", ",")
if lang == "en":
text = text.replace("&", " and ")
elif lang == "fr":
text = text.replace("&", " et ")
elif lang == "pt":
text = text.replace("&", " e ")
return text
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
@ -80,7 +82,7 @@ def transliteration_cleaners(text):
def basic_german_cleaners(text):
'''Pipeline for German text'''
"""Pipeline for German text"""
text = lowercase(text)
text = collapse_whitespace(text)
return text
@ -88,7 +90,7 @@ def basic_german_cleaners(text):
# TODO: elaborate it
def basic_turkish_cleaners(text):
'''Pipeline for Turkish text'''
"""Pipeline for Turkish text"""
text = text.replace("I", "ı")
text = lowercase(text)
text = collapse_whitespace(text)
@ -96,7 +98,7 @@ def basic_turkish_cleaners(text):
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_time_english(text)
@ -109,33 +111,33 @@ def english_cleaners(text):
def french_cleaners(text):
'''Pipeline for French text. There is no need to expand numbers, phonemizer already does that'''
text = expand_abbreviations(text, lang='fr')
"""Pipeline for French text. There is no need to expand numbers, phonemizer already does that"""
text = expand_abbreviations(text, lang="fr")
text = lowercase(text)
text = replace_symbols(text, lang='fr')
text = replace_symbols(text, lang="fr")
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def portuguese_cleaners(text):
'''Basic pipeline for Portuguese text. There is no need to expand abbreviation and
numbers, phonemizer already does that'''
"""Basic pipeline for Portuguese text. There is no need to expand abbreviation and
numbers, phonemizer already does that"""
text = lowercase(text)
text = replace_symbols(text, lang='pt')
text = replace_symbols(text, lang="pt")
text = remove_aux_symbols(text)
text = collapse_whitespace(text)
return text
def chinese_mandarin_cleaners(text: str) -> str:
'''Basic pipeline for chinese'''
"""Basic pipeline for chinese"""
text = replace_numbers_to_characters_in_text(text)
return text
def phoneme_cleaners(text):
'''Pipeline for phonemes mode, including number and abbreviation expansion.'''
"""Pipeline for phonemes mode, including number and abbreviation expansion."""
text = expand_numbers(text)
text = convert_to_ascii(text)
text = expand_abbreviations(text)
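
The cleaners above are plain string-to-string functions chained together by _clean_text. A minimal self-contained sketch of the lowercase plus whitespace-collapse steps, re-implemented inline to mirror basic_cleaners():

import re

_whitespace_re = re.compile(r"\s+")

def collapse_whitespace(text):
    return re.sub(_whitespace_re, " ", text).strip()

def basic_cleaners_sketch(text):
    # same two steps as basic_cleaners(): lowercase, then collapse whitespace
    return collapse_whitespace(text.lower())

print(basic_cleaners_sketch("  Hello   WORLD\t! "))   # "hello world !"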

View File

@ -3,43 +3,116 @@
import re
VALID_SYMBOLS = [
'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
'Y', 'Z', 'ZH'
"AA",
"AA0",
"AA1",
"AA2",
"AE",
"AE0",
"AE1",
"AE2",
"AH",
"AH0",
"AH1",
"AH2",
"AO",
"AO0",
"AO1",
"AO2",
"AW",
"AW0",
"AW1",
"AW2",
"AY",
"AY0",
"AY1",
"AY2",
"B",
"CH",
"D",
"DH",
"EH",
"EH0",
"EH1",
"EH2",
"ER",
"ER0",
"ER1",
"ER2",
"EY",
"EY0",
"EY1",
"EY2",
"F",
"G",
"HH",
"IH",
"IH0",
"IH1",
"IH2",
"IY",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"NG",
"OW",
"OW0",
"OW1",
"OW2",
"OY",
"OY0",
"OY1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH",
"UH0",
"UH1",
"UH2",
"UW",
"UW0",
"UW1",
"UW2",
"V",
"W",
"Y",
"Z",
"ZH",
]
class CMUDict:
'''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
"""Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
def __init__(self, file_or_path, keep_ambiguous=True):
if isinstance(file_or_path, str):
with open(file_or_path, encoding='latin-1') as f:
with open(file_or_path, encoding="latin-1") as f:
entries = _parse_cmudict(f)
else:
entries = _parse_cmudict(file_or_path)
if not keep_ambiguous:
entries = {
word: pron
for word, pron in entries.items() if len(pron) == 1
}
entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
self._entries = entries
def __len__(self):
return len(self._entries)
def lookup(self, word):
'''Returns list of ARPAbet pronunciations of the given word.'''
"""Returns list of ARPAbet pronunciations of the given word."""
return self._entries.get(word.upper())
@staticmethod
def get_arpabet(word, cmudict, punctuation_symbols):
first_symbol, last_symbol = '', ''
first_symbol, last_symbol = "", ""
if word and word[0] in punctuation_symbols:
first_symbol = word[0]
word = word[1:]
@ -48,19 +121,19 @@ class CMUDict:
word = word[:-1]
arpabet = cmudict.lookup(word)
if arpabet is not None:
return first_symbol + '{%s}' % arpabet[0] + last_symbol
return first_symbol + "{%s}" % arpabet[0] + last_symbol
return first_symbol + word + last_symbol
_alt_re = re.compile(r'\([0-9]+\)')
_alt_re = re.compile(r"\([0-9]+\)")
def _parse_cmudict(file):
cmudict = {}
for line in file:
if line and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
parts = line.split(' ')
word = re.sub(_alt_re, '', parts[0])
if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
parts = line.split(" ")
word = re.sub(_alt_re, "", parts[0])
pronunciation = _get_pronunciation(parts[1])
if pronunciation:
if word in cmudict:
@ -71,8 +144,8 @@ def _parse_cmudict(file):
def _get_pronunciation(s):
parts = s.strip().split(' ')
parts = s.strip().split(" ")
for part in parts:
if part not in VALID_SYMBOLS:
return None
return ' '.join(parts)
return " ".join(parts)

View File

@ -5,23 +5,23 @@ import re
from typing import Dict
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_currency_re = re.compile(r'(£|\$|¥)([0-9\,\.]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'-?[0-9]+')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"-?[0-9]+")
def _remove_commas(m):
return m.group(1).replace(',', '')
return m.group(1).replace(",", "")
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
return m.group(1).replace(".", " point ")
def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
parts = value.replace(",", "").split('.')
parts = value.replace(",", "").split(".")
if len(parts) > 2:
return f"{value} {inflection[2]}" # Unexpected format
text = []
@ -31,7 +31,7 @@ def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
text.append(f"{integer} {integer_unit}")
fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if fraction > 0:
fraction_unit = inflection.get(fraction/100, inflection[0.02])
fraction_unit = inflection.get(fraction / 100, inflection[0.02])
text.append(f"{fraction} {fraction_unit}")
if len(text) == 0:
return f"zero {inflection[2]}"
@ -62,7 +62,7 @@ def _expand_currency(m: "re.Match") -> str:
# TODO rin
0.02: "sen",
2: "yen",
}
},
}
unit = m.group(1)
currency = currencies[unit]
@ -78,16 +78,13 @@ def _expand_number(m):
num = int(m.group(0))
if 1000 < num < 3000:
if num == 2000:
return 'two thousand'
return "two thousand"
if 2000 < num < 2010:
return 'two thousand ' + _inflect.number_to_words(num % 100)
return "two thousand " + _inflect.number_to_words(num % 100)
if num % 100 == 0:
return _inflect.number_to_words(num // 100) + ' hundred'
return _inflect.number_to_words(num,
andword='',
zero='oh',
group=2).replace(', ', ' ')
return _inflect.number_to_words(num, andword='')
return _inflect.number_to_words(num // 100) + " hundred"
return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
return _inflect.number_to_words(num, andword="")
def normalize_numbers(text):

View File

@ -1,38 +1,41 @@
# -*- coding: utf-8 -*-
'''
"""
Defines the set of symbols used in text input to the model.
The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
"""
def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name
''' Function to create symbols and phonemes '''
def make_symbols(
characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
): # pylint: disable=redefined-outer-name
""" Function to create symbols and phonemes """
_symbols = [pad, eos, bos] + list(characters)
_phonemes = None
if phonemes is not None:
_phonemes_sorted = sorted(list(set(phonemes)))
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
_arpabet = ['@' + s for s in _phonemes_sorted]
_arpabet = ["@" + s for s in _phonemes_sorted]
# Export all symbols:
_phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
_symbols += _arpabet
return _symbols, _phonemes
_pad = '_'
_eos = '~'
_bos = '^'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
_punctuations = '!\'(),-.:;? '
_pad = "_"
_eos = "~"
_bos = "^"
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "
# Phonemes definition (All IPA characters)
_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
_pulmonic_consonants = 'pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
_suprasegmentals = 'ˈˌːˑ'
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧʲ'
_diacrilics = 'ɚ˞ɫ'
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics
symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)
@ -43,16 +46,18 @@ symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _e
def parse_symbols():
return {'pad': _pad,
'eos': _eos,
'bos': _bos,
'characters': _characters,
'punctuations': _punctuations,
'phonemes': _phonemes}
return {
"pad": _pad,
"eos": _eos,
"bos": _bos,
"characters": _characters,
"punctuations": _punctuations,
"phonemes": _phonemes,
}
if __name__ == '__main__':
if __name__ == "__main__":
print(" > TTS symbols {}".format(len(symbols)))
print(symbols)
print(" > TTS phonemes {}".format(len(phonemes)))
print(''.join(sorted(phonemes)))
print("".join(sorted(phonemes)))

View File

@ -3,13 +3,15 @@ import inflect
_inflect = inflect.engine()
_time_re = re.compile(r"""\b
_time_re = re.compile(
r"""\b
((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3])) # hours
:
([0-5][0-9]) # minutes
\s*(a\.m\.|am|pm|p\.m\.|a\.m|p\.m)? # am/pm
\b""",
re.IGNORECASE | re.X)
re.IGNORECASE | re.X,
)
def _expand_num(n: int) -> str:

View File

@ -3,33 +3,25 @@ import matplotlib
import numpy as np
import torch
matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme
def plot_alignment(alignment,
info=None,
fig_size=(16, 10),
title=None,
output_fig=False):
def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
if isinstance(alignment, torch.Tensor):
alignment_ = alignment.detach().cpu().numpy().squeeze()
else:
alignment_ = alignment
alignment_ = alignment_.astype(
np.float32) if alignment_.dtype == np.float16 else alignment_
alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_
fig, ax = plt.subplots(figsize=fig_size)
im = ax.imshow(alignment_.T,
aspect='auto',
origin='lower',
interpolation='none')
im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none")
fig.colorbar(im, ax=ax)
xlabel = 'Decoder timestep'
xlabel = "Decoder timestep"
if info is not None:
xlabel += '\n\n' + info
xlabel += "\n\n" + info
plt.xlabel(xlabel)
plt.ylabel('Encoder timestep')
plt.ylabel("Encoder timestep")
# plt.yticks(range(len(text)), list(text))
plt.tight_layout()
if title is not None:
@ -39,16 +31,12 @@ def plot_alignment(alignment,
return fig
def plot_spectrogram(spectrogram,
ap=None,
fig_size=(16, 10),
output_fig=False):
def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
if isinstance(spectrogram, torch.Tensor):
spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
else:
spectrogram_ = spectrogram.T
spectrogram_ = spectrogram_.astype(
np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
if ap is not None:
spectrogram_ = ap.denormalize(spectrogram_) # pylint: disable=protected-access
fig = plt.figure(figsize=fig_size)
@ -60,16 +48,18 @@ def plot_spectrogram(spectrogram,
return fig
def visualize(alignment,
postnet_output,
text,
hop_length,
CONFIG,
stop_tokens=None,
decoder_output=None,
output_path=None,
figsize=(8, 24),
output_fig=False):
def visualize(
alignment,
postnet_output,
text,
hop_length,
CONFIG,
stop_tokens=None,
decoder_output=None,
output_path=None,
figsize=(8, 24),
output_fig=False,
):
if decoder_output is not None:
num_plot = 4
@ -86,13 +76,13 @@ def visualize(alignment,
# compute phoneme representation and back
if CONFIG.use_phonemes:
seq = phoneme_to_sequence(
text, [CONFIG.text_cleaner],
text,
[CONFIG.text_cleaner],
CONFIG.phoneme_language,
CONFIG.enable_eos_bos_chars,
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
text = sequence_to_phoneme(
seq,
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
)
text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None)
print(text)
plt.yticks(range(len(text)), list(text))
plt.colorbar()
@ -104,13 +94,15 @@ def visualize(alignment,
# plot postnet spectrogram
plt.subplot(num_plot, 1, 3)
librosa.display.specshow(postnet_output.T,
sr=CONFIG.audio['sample_rate'],
hop_length=hop_length,
x_axis="time",
y_axis="linear",
fmin=CONFIG.audio['mel_fmin'],
fmax=CONFIG.audio['mel_fmax'])
librosa.display.specshow(
postnet_output.T,
sr=CONFIG.audio["sample_rate"],
hop_length=hop_length,
x_axis="time",
y_axis="linear",
fmin=CONFIG.audio["mel_fmin"],
fmax=CONFIG.audio["mel_fmax"],
)
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
@ -119,13 +111,15 @@ def visualize(alignment,
if decoder_output is not None:
plt.subplot(num_plot, 1, 4)
librosa.display.specshow(decoder_output.T,
sr=CONFIG.audio['sample_rate'],
hop_length=hop_length,
x_axis="time",
y_axis="linear",
fmin=CONFIG.audio['mel_fmin'],
fmax=CONFIG.audio['mel_fmax'])
librosa.display.specshow(
decoder_output.T,
sr=CONFIG.audio["sample_rate"],
hop_length=hop_length,
x_axis="time",
y_axis="linear",
fmin=CONFIG.audio["mel_fmin"],
fmax=CONFIG.audio["mel_fmax"],
)
plt.xlabel("Time", fontsize=label_fontsize)
plt.ylabel("Hz", fontsize=label_fontsize)
plt.tight_layout()

View File

@ -29,41 +29,31 @@ def parse_arguments(argv):
parser.add_argument(
"--continue_path",
type=str,
help=("Training output folder to continue training. Used to continue "
"a training. If it is used, 'config_path' is ignored."),
help=(
"Training output folder to continue training. Used to continue "
"a training. If it is used, 'config_path' is ignored."
),
default="",
required="--config_path" not in argv)
required="--config_path" not in argv,
)
parser.add_argument(
"--restore_path",
type=str,
help="Model file to be restored. Use to finetune a model.",
default="")
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
"--best_path",
type=str,
help=("Best model file to be used for extracting best loss."
"If not specified, the latest best model in continue path is used"),
default="")
parser.add_argument(
"--config_path",
type=str,
help="Path to config file for training.",
required="--continue_path" not in argv)
parser.add_argument(
"--debug",
type=bool,
default=False,
help="Do not verify commit integrity to run training.")
parser.add_argument(
"--rank",
type=int,
default=0,
help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument(
"--group_id",
type=str,
help=(
"Best model file to be used for extracting best loss."
"If not specified, the latest best model in continue path is used"
),
default="",
help="DISTRIBUTED: process group id.")
)
parser.add_argument(
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
)
parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")
return parser.parse_args()
@ -86,7 +76,7 @@ def get_last_checkpoint(path):
file_names = glob.glob(os.path.join(path, "*.pth.tar"))
last_models = {}
last_model_nums = {}
for key in ['checkpoint', 'best_model']:
for key in ["checkpoint", "best_model"]:
last_model_num = None
last_model = None
# pass all the checkpoint files and find
@ -105,7 +95,7 @@ def get_last_checkpoint(path):
key_file_names = [fn for fn in file_names if key in fn]
if last_model is None and len(key_file_names) > 0:
last_model = max(key_file_names, key=os.path.getctime)
last_model_num = torch.load(last_model)['step']
last_model_num = torch.load(last_model)["step"]
if last_model is not None:
last_models[key] = last_model
@ -114,16 +104,16 @@ def get_last_checkpoint(path):
# check what models were found
if not last_models:
raise ValueError(f"No models found in continue path {path}!")
if 'checkpoint' not in last_models: # no checkpoint just best model
last_models['checkpoint'] = last_models['best_model']
elif 'best_model' not in last_models: # no best model
if "checkpoint" not in last_models: # no checkpoint just best model
last_models["checkpoint"] = last_models["best_model"]
elif "best_model" not in last_models: # no best model
# this shouldn't happen, but let's handle it just in case
last_models['best_model'] = None
last_models["best_model"] = None
# finally check if last best model is more recent than checkpoint
elif last_model_nums['best_model'] > last_model_nums['checkpoint']:
last_models['checkpoint'] = last_models['best_model']
elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
last_models["checkpoint"] = last_models["best_model"]
return last_models['checkpoint'], last_models['best_model']
return last_models["checkpoint"], last_models["best_model"]
def process_args(args, model_class):
@ -157,13 +147,12 @@ def process_args(args, model_class):
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
if 'mixed_precision' in c and c.mixed_precision:
if "mixed_precision" in c and c.mixed_precision:
print(" > Mixed precision mode is ON")
out_path = args.continue_path
if not out_path:
out_path = create_experiment_folder(c.output_path, c.run_name,
args.debug)
out_path = create_experiment_folder(c.output_path, c.run_name, args.debug)
audio_path = os.path.join(out_path, "test_audios")
@ -179,11 +168,10 @@ def process_args(args, model_class):
# if model characters are not set in the config file
# save the default set to the config file for future
# compatibility.
if model_class == 'tts' and 'characters' not in c:
if model_class == "tts" and "characters" not in c:
used_characters = parse_symbols()
new_fields['characters'] = used_characters
copy_model_files(c, args.config_path,
out_path, new_fields)
new_fields["characters"] = used_characters
copy_model_files(c, args.config_path, out_path, new_fields)
os.chmod(audio_path, 0o775)
os.chmod(out_path, 0o775)
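
For intuition (not part of this commit), the selection in get_last_checkpoint boils down to keeping, per key, the file with the largest trailing step number; a self-contained mimic with hypothetical file names and a simple regex standing in for the repository's own parsing:

import re

def pick_latest(file_names, key):
    """Return the file whose trailing step number is largest for the given key."""
    best, best_step = None, -1
    for name in file_names:
        if key not in name:
            continue
        match = re.search(r"(\d+)\.pth\.tar$", name)
        if match and int(match.group(1)) > best_step:
            best, best_step = name, int(match.group(1))
    return best

files = ["checkpoint_100000.pth.tar", "checkpoint_200000.pth.tar", "best_model_150000.pth.tar"]
print(pick_latest(files, "checkpoint"))   # checkpoint_200000.pth.tar
print(pick_latest(files, "best_model"))   # best_model_150000.pth.tar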

View File

@ -3,11 +3,12 @@ import soundfile as sf
import numpy as np
import scipy.io.wavfile
import scipy.signal
# import pyworld as pw
from TTS.tts.utils.data import StandardScaler
#pylint: disable=too-many-public-methods
# pylint: disable=too-many-public-methods
class AudioProcessor(object):
"""Audio Processor for TTS used by all the data pipelines.
@ -43,35 +44,38 @@ class AudioProcessor(object):
stats_path (str, optional): Path to the computed stats file. Defaults to None.
verbose (bool, optional): enable/disable logging. Defaults to True.
"""
def __init__(self,
sample_rate=None,
resample=False,
num_mels=None,
log_func='np.log10',
min_level_db=None,
frame_shift_ms=None,
frame_length_ms=None,
hop_length=None,
win_length=None,
ref_level_db=None,
fft_size=1024,
power=None,
preemphasis=0.0,
signal_norm=None,
symmetric_norm=None,
max_norm=None,
mel_fmin=None,
mel_fmax=None,
spec_gain=20,
stft_pad_mode='reflect',
clip_norm=True,
griffin_lim_iters=None,
do_trim_silence=False,
trim_db=60,
do_sound_norm=False,
stats_path=None,
verbose=True,
**_):
def __init__(
self,
sample_rate=None,
resample=False,
num_mels=None,
log_func="np.log10",
min_level_db=None,
frame_shift_ms=None,
frame_length_ms=None,
hop_length=None,
win_length=None,
ref_level_db=None,
fft_size=1024,
power=None,
preemphasis=0.0,
signal_norm=None,
symmetric_norm=None,
max_norm=None,
mel_fmin=None,
mel_fmax=None,
spec_gain=20,
stft_pad_mode="reflect",
clip_norm=True,
griffin_lim_iters=None,
do_trim_silence=False,
trim_db=60,
do_sound_norm=False,
stats_path=None,
verbose=True,
**_,
):
# setup class attributed
self.sample_rate = sample_rate
@ -98,14 +102,14 @@ class AudioProcessor(object):
self.do_sound_norm = do_sound_norm
self.stats_path = stats_path
# setup exp_func for db to amp conversion
print(f'self.log_func = {log_func}')
exec(f'self.log_func = {log_func}') #pylint: disable=exec-used
if self.log_func.__name__ == 'log':
print(f"self.log_func = {log_func}")
exec(f"self.log_func = {log_func}") # pylint: disable=exec-used
if self.log_func.__name__ == "log":
self.exp_func = np.exp
elif self.log_func.__name__ == 'log10':
elif self.log_func.__name__ == "log10":
self.exp_func = lambda x: 10 ** x
else:
raise ValueError(' [!] unknown `log_func` value.')
raise ValueError(" [!] unknown `log_func` value.")
# setup stft parameters
if hop_length is None:
# compute stft parameters from given time values
@ -134,17 +138,18 @@ class AudioProcessor(object):
self.symmetric_norm = None
### setting up the parameters ###
def _build_mel_basis(self, ):
def _build_mel_basis(
self,
):
if self.mel_fmax is not None:
assert self.mel_fmax <= self.sample_rate // 2
return librosa.filters.mel(
self.sample_rate,
self.fft_size,
n_mels=self.num_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax)
self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
)
def _stft_parameters(self, ):
def _stft_parameters(
self,
):
"""Compute necessary stft parameters with given time values"""
factor = self.frame_length_ms / self.frame_shift_ms
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
@ -155,24 +160,26 @@ class AudioProcessor(object):
### normalization ###
def normalize(self, S):
"""Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
#pylint: disable=no-else-return
# pylint: disable=no-else-return
S = S.copy()
if self.signal_norm:
# mean-var scaling
if hasattr(self, 'mel_scaler'):
if hasattr(self, "mel_scaler"):
if S.shape[0] == self.num_mels:
return self.mel_scaler.transform(S.T).T
elif S.shape[0] == self.fft_size / 2:
return self.linear_scaler.transform(S.T).T
else:
raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
# range normalization
S -= self.ref_level_db # discard certain range of DB assuming it is air noise
S_norm = ((S - self.min_level_db) / (-self.min_level_db))
S_norm = (S - self.min_level_db) / (-self.min_level_db)
if self.symmetric_norm:
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
if self.clip_norm:
S_norm = np.clip(S_norm, -self.max_norm, self.max_norm) # pylint: disable=invalid-unary-operand-type
S_norm = np.clip(
S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
)
return S_norm
else:
S_norm = self.max_norm * S_norm
@ -184,47 +191,49 @@ class AudioProcessor(object):
def denormalize(self, S):
"""denormalize values"""
#pylint: disable=no-else-return
# pylint: disable=no-else-return
S_denorm = S.copy()
if self.signal_norm:
# mean-var scaling
if hasattr(self, 'mel_scaler'):
if hasattr(self, "mel_scaler"):
if S_denorm.shape[0] == self.num_mels:
return self.mel_scaler.inverse_transform(S_denorm.T).T
elif S_denorm.shape[0] == self.fft_size / 2:
return self.linear_scaler.inverse_transform(S_denorm.T).T
else:
raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
if self.symmetric_norm:
if self.clip_norm:
S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm) #pylint: disable=invalid-unary-operand-type
S_denorm = np.clip(
S_denorm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
)
S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
return S_denorm + self.ref_level_db
else:
if self.clip_norm:
S_denorm = np.clip(S_denorm, 0, self.max_norm)
S_denorm = (S_denorm * -self.min_level_db /
self.max_norm) + self.min_level_db
S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
return S_denorm + self.ref_level_db
else:
return S_denorm
### Mean-STD scaling ###
def load_stats(self, stats_path):
stats = np.load(stats_path, allow_pickle=True).item() #pylint: disable=unexpected-keyword-arg
mel_mean = stats['mel_mean']
mel_std = stats['mel_std']
linear_mean = stats['linear_mean']
linear_std = stats['linear_std']
stats_config = stats['audio_config']
stats = np.load(stats_path, allow_pickle=True).item() # pylint: disable=unexpected-keyword-arg
mel_mean = stats["mel_mean"]
mel_std = stats["mel_std"]
linear_mean = stats["linear_mean"]
linear_std = stats["linear_std"]
stats_config = stats["audio_config"]
# check all audio parameters used for computing stats
skip_parameters = ['griffin_lim_iters', 'stats_path', 'do_trim_silence', 'ref_level_db', 'power']
skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
for key in stats_config.keys():
if key in skip_parameters:
continue
if key not in ['sample_rate', 'trim_db']:
assert stats_config[key] == self.__dict__[key],\
f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
if key not in ["sample_rate", "trim_db"]:
assert (
stats_config[key] == self.__dict__[key]
), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
return mel_mean, mel_std, linear_mean, linear_std, stats_config
# pylint: disable=attribute-defined-outside-init
@ -243,7 +252,6 @@ class AudioProcessor(object):
def _db_to_amp(self, x):
return self.exp_func(x / self.spec_gain)
### Preemphasis ###
def apply_preemphasis(self, x):
if self.preemphasis == 0:
@ -284,17 +292,17 @@ class AudioProcessor(object):
S = self._db_to_amp(S)
# Reconstruct phase
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
def inv_melspectrogram(self, mel_spectrogram):
'''Converts melspectrogram to waveform using librosa'''
"""Converts melspectrogram to waveform using librosa"""
D = self.denormalize(mel_spectrogram)
S = self._db_to_amp(D)
S = self._mel_to_linear(S) # Convert back to linear
if self.preemphasis != 0:
return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
return self._griffin_lim(S**self.power)
return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
return self._griffin_lim(S ** self.power)
def out_linear_to_mel(self, linear_spec):
S = self.denormalize(linear_spec)
@ -315,8 +323,7 @@ class AudioProcessor(object):
)
def _istft(self, y):
return librosa.istft(
y, hop_length=self.hop_length, win_length=self.win_length)
return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)
def _griffin_lim(self, S):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
@ -328,8 +335,7 @@ class AudioProcessor(object):
return y
def compute_stft_paddings(self, x, pad_sides=1):
'''compute right padding (final frame) or both sides padding (first and final frames)
'''
"""compute right padding (final frame) or both sides padding (first and final frames)"""
assert pad_sides in (1, 2)
pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
if pad_sides == 1:
@ -354,7 +360,7 @@ class AudioProcessor(object):
hop_length = int(window_length / 4)
threshold = self._db_to_amp(threshold_db)
for x in range(hop_length, len(wav) - window_length, hop_length):
if np.max(wav[x:x + window_length]) < threshold:
if np.max(wav[x : x + window_length]) < threshold:
return x + hop_length
return len(wav)
@ -362,8 +368,9 @@ class AudioProcessor(object):
""" Trim silent parts with a threshold and 0.01 sec margin """
margin = int(self.sample_rate * 0.01)
wav = wav[margin:-margin]
return librosa.effects.trim(
wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0]
return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
0
]
@staticmethod
def sound_norm(x):
@ -375,14 +382,14 @@ class AudioProcessor(object):
x, sr = librosa.load(filename, sr=self.sample_rate)
elif sr is None:
x, sr = sf.read(filename)
assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
else:
x, sr = librosa.load(filename, sr=sr)
if self.do_trim_silence:
try:
x = self.trim_silence(x)
except ValueError:
print(f' [!] File cannot be trimmed for silence - {filename}')
print(f" [!] File cannot be trimmed for silence - {filename}")
if self.do_sound_norm:
x = self.sound_norm(x)
return x
@ -396,10 +403,12 @@ class AudioProcessor(object):
def mulaw_encode(wav, qc):
mu = 2 ** qc - 1
# wav_abs = np.minimum(np.abs(wav), 1.0)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu)
signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
# Quantize signal to the specified number of levels.
signal = (signal + 1) / 2 * mu + 0.5
return np.floor(signal,)
return np.floor(
signal,
)
@staticmethod
def mulaw_decode(wav, qc):
@ -408,15 +417,14 @@ class AudioProcessor(object):
x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
return x
@staticmethod
def encode_16bits(x):
return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
@staticmethod
def quantize(x, bits):
return (x + 1.) * (2**bits - 1) / 2
return (x + 1.0) * (2 ** bits - 1) / 2
@staticmethod
def dequantize(x, bits):
return 2 * x / (2**bits - 1) - 1
return 2 * x / (2 ** bits - 1) - 1
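
A quick round-trip through the mu-law and linear quantization helpers above (not part of this commit; the static methods are assumed to be callable on the class imported from TTS.utils.audio as shown):

import numpy as np
from TTS.utils.audio import AudioProcessor

x = np.linspace(-1.0, 1.0, 9, dtype=np.float32)        # toy samples in [-1, 1]

# mu-law companding to 2**10 levels and back
mu = 2 ** 10 - 1
levels = AudioProcessor.mulaw_encode(x, qc=10)          # integer levels in [0, mu]
decoded = AudioProcessor.mulaw_decode(2 * levels / mu - 1, qc=10)
print(np.max(np.abs(x - decoded)))                      # small companding error

# linear quantize/dequantize are exact inverses (no rounding happens here)
q = AudioProcessor.quantize(x, bits=10)
print(np.max(np.abs(x - AudioProcessor.dequantize(q, bits=10))))  # ~0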

View File

@ -2,19 +2,21 @@ import datetime
from TTS.utils.io import AttrDict
tcolors = AttrDict({
'OKBLUE': '\033[94m',
'HEADER': '\033[95m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
'ENDC': '\033[0m',
'BOLD': '\033[1m',
'UNDERLINE': '\033[4m'
})
tcolors = AttrDict(
{
"OKBLUE": "\033[94m",
"HEADER": "\033[95m",
"OKGREEN": "\033[92m",
"WARNING": "\033[93m",
"FAIL": "\033[91m",
"ENDC": "\033[0m",
"BOLD": "\033[1m",
"UNDERLINE": "\033[4m",
}
)
class ConsoleLogger():
class ConsoleLogger:
def __init__(self):
# TODO: color code for value changes
# use these to compare values between iterations
@ -28,23 +30,24 @@ class ConsoleLogger():
return now.strftime("%Y-%m-%d %H:%M:%S")
def print_epoch_start(self, epoch, max_epoch):
print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
epoch, max_epoch, tcolors.ENDC),
flush=True)
print(
"\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
flush=True,
)
def print_train_start(self):
print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")
def print_train_step(self, batch_steps, step, global_step, log_dict,
loss_dict, avg_loss_dict):
def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict):
indent = " | > "
print()
log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC
)
for key, value in loss_dict.items():
# print the avg value if given
if f'avg_{key}' in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
if f"avg_{key}" in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
for idx, (key, value) in enumerate(log_dict.items()):
@ -52,13 +55,12 @@ class ConsoleLogger():
log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
else:
log_text += f"{indent}{key}: {value}"
if idx < len(log_dict)-1:
if idx < len(log_dict) - 1:
log_text += "\n"
print(log_text, flush=True)
# pylint: disable=unused-argument
def print_train_epoch_end(self, global_step, epoch, epoch_time,
print_dict):
def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict):
indent = " | > "
log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
for key, value in print_dict.items():
@ -74,29 +76,28 @@ class ConsoleLogger():
log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
for key, value in loss_dict.items():
# print the avg value if given
if f'avg_{key}' in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
if f"avg_{key}" in avg_loss_dict.keys():
log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
else:
log_text += "{}{}: {:.5f} \n".format(indent, key, value)
print(log_text, flush=True)
def print_epoch_end(self, epoch, avg_loss_dict):
indent = " | > "
log_text = " {}--> EVAL PERFORMANCE{}\n".format(
tcolors.BOLD, tcolors.ENDC)
log_text = " {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC)
for key, value in avg_loss_dict.items():
# print the avg value if given
color = ''
sign = '+'
color = ""
sign = "+"
diff = 0
if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
diff = value - self.old_eval_loss_dict[key]
if diff < 0:
color = tcolors.OKGREEN
sign = ''
sign = ""
elif diff > 0:
color = tcolors.FAIL
sign = '+'
sign = "+"
log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
self.old_eval_loss_dict = avg_loss_dict
print(log_text, flush=True)

View File

@ -14,7 +14,7 @@ class DistributedSampler(Sampler):
"""
def __init__(self, dataset, num_replicas=None, rank=None):
super(DistributedSampler, self).__init__(dataset)
super().__init__(dataset)
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@ -34,11 +34,11 @@ class DistributedSampler(Sampler):
indices = torch.arange(len(self.dataset)).tolist()
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
indices += indices[: (self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
@ -64,12 +64,7 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
torch.cuda.set_device(rank % torch.cuda.device_count())
# Initialize distributed communication
dist.init_process_group(
dist_backend,
init_method=dist_url,
world_size=num_gpus,
rank=rank,
group_name=group_name)
dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)
def apply_gradient_allreduce(module):
@ -97,14 +92,13 @@ def apply_gradient_allreduce(module):
coalesced = _flatten_dense_tensors(grads)
dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
coalesced /= dist.get_world_size()
for buf, synced in zip(
grads, _unflatten_dense_tensors(coalesced, grads)):
for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)
for param in list(module.parameters()):
def allreduce_hook(*_):
Variable._execution_engine.queue_callback(allreduce_params) #pylint: disable=protected-access
Variable._execution_engine.queue_callback(allreduce_params) # pylint: disable=protected-access
if param.requires_grad:
param.register_hook(allreduce_hook)
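
The index bookkeeping in DistributedSampler above is easiest to see with concrete numbers; a standalone sketch (illustration only, sizes made up):

num_replicas, total = 3, 10
indices = list(range(total))

# pad so the list divides evenly across replicas (same trick as above)
total_size = ((total + num_replicas - 1) // num_replicas) * num_replicas
indices += indices[: total_size - len(indices)]

# each rank takes every num_replicas-th index starting at its own rank
for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# 0 [0, 3, 6, 9]
# 1 [1, 4, 7, 0]
# 2 [2, 5, 8, 1]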

View File

@ -10,8 +10,7 @@ from pathlib import Path
def get_git_branch():
try:
out = subprocess.check_output(["git", "branch"]).decode("utf8")
current = next(line for line in out.split("\n")
if line.startswith("*"))
current = next(line for line in out.split("\n") if line.startswith("*"))
current.replace("* ", "")
except subprocess.CalledProcessError:
current = "inside_docker"
@ -29,12 +28,11 @@ def get_commit_hash():
# raise RuntimeError(
# " !! Commit before training to get the commit hash.")
try:
commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
# Not copying .git folder into docker container
except (subprocess.CalledProcessError, FileNotFoundError):
commit = "0000000"
print(' > Git Hash: {}'.format(commit))
print(" > Git Hash: {}".format(commit))
return commit
@ -42,11 +40,10 @@ def create_experiment_folder(root_path, model_name, debug):
""" Create a folder with the current date and time """
date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
if debug:
commit_hash = 'debug'
commit_hash = "debug"
else:
commit_hash = get_commit_hash()
output_folder = os.path.join(
root_path, model_name + '-' + date_str + '-' + commit_hash)
output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
os.makedirs(output_folder, exist_ok=True)
print(" > Experiment folder: {}".format(output_folder))
return output_folder
@ -72,16 +69,16 @@ def count_parameters(model):
def get_user_data_dir(appname):
if sys.platform == "win32":
import winreg # pylint: disable=import-outside-toplevel
key = winreg.OpenKey(
winreg.HKEY_CURRENT_USER,
r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
)
dir_, _ = winreg.QueryValueEx(key, "Local AppData")
ans = Path(dir_).resolve(strict=False)
elif sys.platform == 'darwin':
ans = Path('~/Library/Application Support/').expanduser()
elif sys.platform == "darwin":
ans = Path("~/Library/Application Support/").expanduser()
else:
ans = Path.home().joinpath('.local/share')
ans = Path.home().joinpath(".local/share")
return ans.joinpath(appname)
@ -91,32 +88,20 @@ def set_init_dict(model_dict, checkpoint_state, c):
if k not in model_dict:
print(" | > Layer missing in the model definition: {}".format(k))
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
for k, v in checkpoint_state.items() if k in model_dict
}
pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
# 2. filter out different size layers
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if v.numel() == model_dict[k].numel()
}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
# 3. skip reinit layers
if c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if reinit_layer_name not in k
}
pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
len(model_dict)))
print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
return model_dict
class KeepAverage():
class KeepAverage:
def __init__(self):
self.avg_values = {}
self.iters = {}
@ -141,8 +126,7 @@ class KeepAverage():
self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
self.iters[name] += 1
else:
self.avg_values[name] = self.avg_values[name] * \
self.iters[name] + value
self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
self.iters[name] += 1
self.avg_values[name] /= self.iters[name]
@ -155,23 +139,27 @@ class KeepAverage():
self.update_value(key, value)
def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
def check_argument(
name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None
):
if alternative in c.keys() and c[alternative] is not None:
return
if restricted:
assert name in c.keys(), f' [!] {name} not defined in config.json'
assert name in c.keys(), f" [!] {name} not defined in config.json"
if name in c.keys():
if max_val:
assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
if min_val:
assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
if enum_list:
assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
assert c[name].lower() in enum_list, f" [!] {name} is not a valid value"
if isinstance(val_type, list):
is_valid = False
for typ in val_type:
if isinstance(c[name], typ):
is_valid = True
assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
elif val_type:
assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
assert (
isinstance(c[name], val_type) or c[name] is None
), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"

View File

@ -8,15 +8,17 @@ from shutil import copyfile
class RenamingUnpickler(pickle_tts.Unpickler):
"""Overload default pickler to solve module renaming problem"""
def find_class(self, module, name):
return super().find_class(module.replace('mozilla_voice_tts', 'TTS'), name)
return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)
class AttrDict(dict):
"""A custom dict which converts dict keys
to class attributes"""
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.__dict__ = self
@ -25,11 +27,12 @@ def read_json_with_comments(json_path):
with open(json_path, "r", encoding="utf-8") as f:
input_str = f.read()
# handle comments
input_str = re.sub(r'\\\n', '', input_str)
input_str = re.sub(r'//.*\n', '\n', input_str)
input_str = re.sub(r"\\\n", "", input_str)
input_str = re.sub(r"//.*\n", "\n", input_str)
data = json.loads(input_str)
return data
def load_config(config_path: str) -> AttrDict:
"""Load config files and discard comments
@ -60,7 +63,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
in the config file.
"""
# copy config.json
copy_config_path = os.path.join(out_path, 'config.json')
copy_config_path = os.path.join(out_path, "config.json")
config_lines = open(config_file, "r", encoding="utf-8").readlines()
# add extra information fields
for key, value in new_fields.items():
@ -73,7 +76,10 @@ def copy_model_files(c, config_file, out_path, new_fields):
config_out_file.writelines(config_lines)
config_out_file.close()
# copy model stats file if available
if c.audio['stats_path'] is not None:
copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
if c.audio["stats_path"] is not None:
copy_stats_path = os.path.join(out_path, "scale_stats.npy")
if not os.path.exists(copy_stats_path):
copyfile(c.audio['stats_path'], copy_stats_path, )
copyfile(
c.audio["stats_path"],
copy_stats_path,
)
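
The two regex substitutions in read_json_with_comments above are worth seeing on a concrete snippet (illustration only; the config content is made up):

import json
import re

input_str = """{
    // training run name
    "run_name": "ljspeech-ddc",  // inline comments are stripped too
    "batch_size": 32
}"""

# same two substitutions as read_json_with_comments above
input_str = re.sub(r"\\\n", "", input_str)       # join escaped line breaks
input_str = re.sub(r"//.*\n", "\n", input_str)   # drop //-style comments
print(json.loads(input_str))                     # {'run_name': 'ljspeech-ddc', 'batch_size': 32}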

View File

@ -22,12 +22,13 @@ class ModelManager(object):
Args:
models_file (str): path to .model.json
"""
def __init__(self, models_file=None, output_prefix=None):
super().__init__()
if output_prefix is None:
self.output_prefix = get_user_data_dir('tts')
self.output_prefix = get_user_data_dir("tts")
else:
self.output_prefix = os.path.join(output_prefix, 'tts')
self.output_prefix = os.path.join(output_prefix, "tts")
self.url_prefix = "https://drive.google.com/uc?id="
self.models_dict = None
if models_file is not None:
@ -72,7 +73,7 @@ class ModelManager(object):
print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
else:
print(f" >: {model_type}/{lang}/{dataset}/{model}")
models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}')
models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
return models_name_list
def download_model(self, model_name):
@ -104,25 +105,25 @@ class ModelManager(object):
else:
os.makedirs(output_path, exist_ok=True)
print(f" > Downloading model to {output_path}")
output_stats_path = os.path.join(output_path, 'scale_stats.npy')
output_stats_path = os.path.join(output_path, "scale_stats.npy")
# download files to the output path
if self._check_dict_key(model_item, 'github_rls_url'):
if self._check_dict_key(model_item, "github_rls_url"):
# download from github release
# TODO: pass output_path
self._download_zip_file(model_item['github_rls_url'], output_path)
self._download_zip_file(model_item["github_rls_url"], output_path)
else:
# download from gdrive
self._download_gdrive_file(model_item['model_file'], output_model_path)
self._download_gdrive_file(model_item['config_file'], output_config_path)
if self._check_dict_key(model_item, 'stats_file'):
self._download_gdrive_file(model_item['stats_file'], output_stats_path)
self._download_gdrive_file(model_item["model_file"], output_model_path)
self._download_gdrive_file(model_item["config_file"], output_config_path)
if self._check_dict_key(model_item, "stats_file"):
self._download_gdrive_file(model_item["stats_file"], output_stats_path)
# set the scale_path.npy file path in the model config.json
if self._check_dict_key(model_item, 'stats_file') or os.path.exists(output_stats_path):
if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path):
# set scale stats path in config.json
config_path = output_config_path
config = load_config(config_path)
config["audio"]['stats_path'] = output_stats_path
config["audio"]["stats_path"] = output_stats_path
with open(config_path, "w") as jf:
json.dump(config, jf)
return output_model_path, output_config_path, model_item
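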

View File

@ -6,7 +6,6 @@ from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
if lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
@ -20,13 +19,15 @@ class RAdam(Optimizer):
self.degenerated_to_sgd = degenerated_to_sgd
if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
for param in params:
if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
param['buffer'] = [[None, None, None] for _ in range(10)]
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
super(RAdam, self).__init__(params, defaults)
if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]):
param["buffer"] = [[None, None, None] for _ in range(10)]
defaults = dict(
lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]
)
super().__init__(params, defaults)
def __setstate__(self, state): # pylint: disable=useless-super-delegation
super(RAdam, self).__setstate__(state)
super().__setstate__(state)
def step(self, closure=None):
@ -36,62 +37,70 @@ class RAdam(Optimizer):
for group in self.param_groups:
for p in group['params']:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
raise RuntimeError("RAdam does not support sparse gradients")
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p_data_fp32)
state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
buffered = group['buffer'][int(state['step'] % 10)]
if state['step'] == buffered[0]:
state["step"] += 1
buffered = group["buffer"][int(state["step"] % 10)]
if state["step"] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
buffered[0] = state["step"]
beta2_t = beta2 ** state["step"]
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
step_size = math.sqrt(
(1 - beta2_t)
* (N_sma - 4)
/ (N_sma_max - 4)
* (N_sma - 2)
/ N_sma
* N_sma_max
/ (N_sma_max - 2)
) / (1 - beta1 ** state["step"])
elif self.degenerated_to_sgd:
step_size = 1.0 / (1 - beta1 ** state['step'])
step_size = 1.0 / (1 - beta1 ** state["step"])
else:
step_size = -1
buffered[2] = step_size
# more conservative since it's an approximated value
if N_sma >= 5:
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
denom = exp_avg_sq.sqrt().add_(group["eps"])
p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
elif step_size > 0:
if group['weight_decay'] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
if group["weight_decay"] != 0:
p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
p.data.copy_(p_data_fp32)
return loss
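
A minimal usage sketch for the optimizer above (not part of this commit; it assumes the class stays importable as TTS.utils.radam.RAdam, and the model is a throwaway linear layer):

import torch
from TTS.utils.radam import RAdam

model = torch.nn.Linear(10, 1)
optimizer = RAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)

x, y = torch.randn(8, 10), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()        # rectified Adam update as implemented above
optimizer.zero_grad()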

View File

@ -9,6 +9,7 @@ from TTS.utils.io import load_config
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.speakers import load_speaker_mapping
from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input
# pylint: disable=unused-wildcard-import
# pylint: disable=wildcard-import
from TTS.tts.utils.synthesis import synthesis, trim_silence
@ -49,12 +50,11 @@ class Synthesizer(object):
self.use_cuda = use_cuda
if self.use_cuda:
assert torch.cuda.is_available(), "CUDA is not available on this machine."
self.load_tts(tts_checkpoint, tts_config,
use_cuda)
self.output_sample_rate = self.tts_config.audio['sample_rate']
self.load_tts(tts_checkpoint, tts_config, use_cuda)
self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
self.output_sample_rate = self.vocoder_config.audio['sample_rate']
self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
@staticmethod
def get_segmenter(lang):
@ -69,16 +69,18 @@ class Synthesizer(object):
self.num_speakers = 0
# set external speaker embedding
if self.tts_config.use_external_speaker_embedding_file:
speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]['embedding']
speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
self.speaker_embedding_dim = len(speaker_embedding)
def init_speaker(self, speaker_idx):
# load speakers
speaker_embedding = None
if hasattr(self, 'tts_speakers') and speaker_idx is not None:
assert speaker_idx < len(self.tts_speakers), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
if hasattr(self, "tts_speakers") and speaker_idx is not None:
assert speaker_idx < len(
self.tts_speakers
), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
if self.tts_config.use_external_speaker_embedding_file:
speaker_embedding = self.tts_speakers[speaker_idx]['embedding']
speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
return speaker_embedding
def load_tts(self, tts_checkpoint, tts_config, use_cuda):
@ -90,7 +92,7 @@ class Synthesizer(object):
self.use_phonemes = self.tts_config.use_phonemes
self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)
if 'characters' in self.tts_config.keys():
if "characters" in self.tts_config.keys():
symbols, phonemes = make_symbols(**self.tts_config.characters)
if self.use_phonemes:
@ -105,7 +107,7 @@ class Synthesizer(object):
def load_vocoder(self, model_file, model_config, use_cuda):
self.vocoder_config = load_config(model_config)
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config['audio'])
self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
self.vocoder_model = setup_generator(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
if use_cuda:
@ -141,7 +143,8 @@ class Synthesizer(object):
False,
self.tts_config.enable_eos_bos_chars,
use_gl,
speaker_embedding=speaker_embedding)
speaker_embedding=speaker_embedding,
)
if not use_gl:
# denormalize tts output based on tts audio config
mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T
@ -149,7 +152,7 @@ class Synthesizer(object):
# renormalize spectrogram based on vocoder config
vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
# compute scale factor for possible sample rate mismatch
scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
if scale_factor[1] != 1:
print(" > interpolating tts model output.")
vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)
@ -172,7 +175,7 @@ class Synthesizer(object):
# compute stats
process_time = time.time() - start_time
audio_time = len(wavs) / self.tts_config.audio['sample_rate']
audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
print(f" > Processing time: {process_time}")
print(f" > Real-time factor: {process_time / audio_time}")
return wavs

View File

@ -13,40 +13,28 @@ class TensorboardLogger(object):
layer_num = 1
for name, param in model.named_parameters():
if param.numel() == 1:
self.writer.add_scalar(
"layer{}-{}/value".format(layer_num, name),
param.max(), step)
self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step)
else:
self.writer.add_scalar(
"layer{}-{}/max".format(layer_num, name),
param.max(), step)
self.writer.add_scalar(
"layer{}-{}/min".format(layer_num, name),
param.min(), step)
self.writer.add_scalar(
"layer{}-{}/mean".format(layer_num, name),
param.mean(), step)
self.writer.add_scalar(
"layer{}-{}/std".format(layer_num, name),
param.std(), step)
self.writer.add_histogram(
"layer{}-{}/param".format(layer_num, name), param, step)
self.writer.add_histogram(
"layer{}-{}/grad".format(layer_num, name), param.grad, step)
self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step)
self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step)
self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step)
self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step)
self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
layer_num += 1
def dict_to_tb_scalar(self, scope_name, stats, step):
for key, value in stats.items():
self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step)
self.writer.add_scalar("{}/{}".format(scope_name, key), value, step)
def dict_to_tb_figure(self, scope_name, figures, step):
for key, value in figures.items():
self.writer.add_figure('{}/{}'.format(scope_name, key), value, step)
self.writer.add_figure("{}/{}".format(scope_name, key), value, step)
def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
for key, value in audios.items():
try:
self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate)
self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
except RuntimeError:
traceback.print_exc()

View File

@ -14,12 +14,13 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):
def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
r'''Check model gradient against unexpected jumps and failures'''
r"""Check model gradient against unexpected jumps and failures"""
skip_flag = False
if ignore_stopnet:
if not amp_opt_params:
grad_norm = torch.nn.utils.clip_grad_norm_(
[param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
[param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip
)
else:
grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
else:
@ -41,11 +42,10 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
def lr_decay(init_lr, global_step, warmup_steps):
r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py"""
warmup_steps = float(warmup_steps)
step = global_step + 1.
lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
step**-0.5)
step = global_step + 1.0
lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
return lr
@ -54,14 +54,14 @@ def adam_weight_decay(optimizer):
Custom weight decay operation, not affecting grad values.
"""
for group in optimizer.param_groups:
for param in group['params']:
current_lr = group['lr']
weight_decay = group['weight_decay']
factor = -weight_decay * group['lr']
param.data = param.data.add(param.data,
alpha=factor)
for param in group["params"]:
current_lr = group["lr"]
weight_decay = group["weight_decay"]
factor = -weight_decay * group["lr"]
param.data = param.data.add(param.data, alpha=factor)
return optimizer, current_lr
# pylint: disable=dangerous-default-value
def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
"""
@ -74,30 +74,23 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn
if not param.requires_grad:
continue
if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
no_decay.append(param)
else:
decay.append(param)
return [{
'params': no_decay,
'weight_decay': 0.
}, {
'params': decay,
'weight_decay': weight_decay
}]
return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]
# pylint: disable=protected-access
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
self.warmup_steps = float(warmup_steps)
super(NoamLR, self).__init__(optimizer, last_epoch)
super().__init__(optimizer, last_epoch)
def get_lr(self):
step = max(self.last_epoch, 1)
return [
base_lr * self.warmup_steps**0.5 *
min(step * self.warmup_steps**-1.5, step**-0.5)
base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5)
for base_lr in self.base_lrs
]
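
The warmup schedule shared by lr_decay and NoamLR above is compact enough to evaluate by hand (not part of the commit; the numbers are illustrative):

init_lr, warmup_steps = 1e-3, 4000

def noam_lr(step):
    step = max(step, 1)
    return init_lr * warmup_steps ** 0.5 * min(step * warmup_steps ** -1.5, step ** -0.5)

for step in (1, 2000, 4000, 16000):
    print(step, noam_lr(step))
# lr ramps up linearly until step == warmup_steps (where it equals init_lr),
# then decays proportionally to step ** -0.5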

View File

@ -13,19 +13,22 @@ class GANDataset(Dataset):
and converts them to acoustic features on the fly and returns
random segments of (audio, feature) couples.
"""
def __init__(self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False):
super(GANDataset, self).__init__()
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
self.item_list = items
self.compute_feat = not isinstance(items[0], (tuple, list))
@ -57,14 +60,14 @@ class GANDataset(Dataset):
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
def __len__(self):
return len(self.item_list)
def __getitem__(self, idx):
""" Return different items for Generator and Discriminator and
cache acoustic features """
"""Return different items for Generator and Discriminator and
cache acoustic features"""
if self.return_segments:
idx2 = self.G_to_D_mappings[idx]
item1 = self.load_item(idx)
@ -76,13 +79,16 @@ class GANDataset(Dataset):
def _pad_short_samples(self, audio, mel=None):
"""Pad samples shorter than the output sequence length"""
if len(audio) < self.seq_len:
audio = np.pad(audio, (0, self.seq_len - len(audio)),
mode='constant',
constant_values=0.0)
audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)
if mel is not None and mel.shape[1] < self.feat_frame_len:
pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
mel = np.pad(mel, ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), mode='constant', constant_values=pad_value.mean())
mel = np.pad(
mel,
([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
mode="constant",
constant_values=pad_value.mean(),
)
return audio, mel
def shuffle_mapping(self):
@ -111,12 +117,14 @@ class GANDataset(Dataset):
else:
audio = self.ap.load_wav(wavpath)
mel = np.load(feat_path)
audio, mel= self._pad_short_samples(audio, mel)
audio, mel = self._pad_short_samples(audio, mel)
# correct the audio length wrt padding applied in stft
audio = np.pad(audio, (0, self.hop_len), mode="edge")
audio = audio[:mel.shape[-1] * self.hop_len]
assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'
audio = audio[: mel.shape[-1] * self.hop_len]
assert (
mel.shape[-1] * self.hop_len == audio.shape[-1]
), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"
audio = torch.from_numpy(audio).float().unsqueeze(0)
mel = torch.from_numpy(mel).float().squeeze(0)
@ -128,8 +136,7 @@ class GANDataset(Dataset):
mel = mel[:, mel_start:mel_end]
audio_start = mel_start * self.hop_len
audio = audio[:, audio_start:audio_start +
self.seq_len]
audio = audio[:, audio_start : audio_start + self.seq_len]
if self.use_noise_augment and self.is_training and self.return_segments:
audio = audio + (1 / 32768) * torch.randn_like(audio)
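
The frame-to-sample arithmetic behind the random segments above, spelled out standalone (not part of the commit; shapes are made up):

import numpy as np

hop_len, seq_len = 256, 8192            # 8192-sample training segments
frames_per_seg = seq_len // hop_len     # 32 mel frames per segment

mel = np.random.rand(80, 100)           # 80 bands x 100 frames
audio = np.random.rand(100 * hop_len)   # matching waveform length

mel_start = np.random.randint(0, mel.shape[1] - frames_per_seg)
mel_seg = mel[:, mel_start : mel_start + frames_per_seg]

audio_start = mel_start * hop_len       # same offset expressed in samples
audio_seg = audio[audio_start : audio_start + seq_len]

assert mel_seg.shape[1] * hop_len == audio_seg.shape[0]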

View File

@ -18,11 +18,7 @@ def preprocess_wav_files(out_path, config, ap):
mel = ap.melspectrogram(y)
np.save(mel_path, mel)
if isinstance(config.mode, int):
quant = (
ap.mulaw_encode(y, qc=config.mode)
if config.mulaw
else ap.quantize(y, bits=config.mode)
)
quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
np.save(quant_path, quant)

View File

@ -13,18 +13,21 @@ class WaveGradDataset(Dataset):
and converts them to acoustic features on the fly and returns
random segments of (audio, feature) couples.
"""
def __init__(self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False):
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad_short,
conv_pad=2,
is_training=True,
return_segments=True,
use_noise_augment=False,
use_cache=False,
verbose=False,
):
super().__init__()
self.ap = ap
@ -54,7 +57,7 @@ class WaveGradDataset(Dataset):
@staticmethod
def find_wav_files(path):
return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)
def __len__(self):
return len(self.item_list)
@ -86,13 +89,16 @@ class WaveGradDataset(Dataset):
if self.return_segments:
# correct audio length wrt segment length
if audio.shape[-1] < self.seq_len + self.pad_short:
audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
mode='constant', constant_values=0.0)
assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
audio = np.pad(
audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
)
assert (
audio.shape[-1] >= self.seq_len + self.pad_short
), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
# correct the audio length wrt hop length
p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
audio = np.pad(audio, (0, p), mode='constant', constant_values=0.0)
audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)
if self.use_cache:
self.cache[idx] = audio
@ -126,7 +132,7 @@ class WaveGradDataset(Dataset):
for idx, b in enumerate(batch):
mel = b[0]
audio = b[1]
mels[idx, :, :mel.shape[1]] = mel
audios[idx, :audio.shape[0]] = audio
mels[idx, :, : mel.shape[1]] = mel
audios[idx, : audio.shape[0]] = audio
return mels, audios

View File

@ -9,19 +9,20 @@ class WaveRNNDataset(Dataset):
and converts them to acoustic features on the fly.
"""
def __init__(self,
ap,
items,
seq_len,
hop_len,
pad,
mode,
mulaw,
is_training=True,
verbose=False,
):
def __init__(
self,
ap,
items,
seq_len,
hop_len,
pad,
mode,
mulaw,
is_training=True,
verbose=False,
):
super(WaveRNNDataset, self).__init__()
super().__init__()
self.ap = ap
self.compute_feat = not isinstance(items[0], (tuple, list))
self.item_list = items
@ -61,8 +62,9 @@ class WaveRNNDataset(Dataset):
if self.mode in ["gauss", "mold"]:
x_input = audio
elif isinstance(self.mode, int):
x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
if self.mulaw else self.ap.quantize(audio, bits=self.mode))
x_input = (
self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
)
else:
raise RuntimeError("Unknown dataset mode - ", self.mode)
@ -71,7 +73,7 @@ class WaveRNNDataset(Dataset):
wavpath, feat_path = self.item_list[index]
mel = np.load(feat_path.replace("/quant/", "/mel/"))
if mel.shape[-1] < self.mel_len + 2 * self.pad:
if mel.shape[-1] < self.mel_len + 2 * self.pad:
print(" [!] Instance is too short! : {}".format(wavpath))
self.item_list[index] = self.item_list[index + 1]
feat_path = self.item_list[index]
@ -87,22 +89,14 @@ class WaveRNNDataset(Dataset):
def collate(self, batch):
mel_win = self.seq_len // self.hop_len + 2 * self.pad
max_offsets = [x[0].shape[-1] -
(mel_win + 2 * self.pad) for x in batch]
max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]
mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
sig_offsets = [(offset + self.pad) *
self.hop_len for offset in mel_offsets]
sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]
mels = [
x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
for i, x in enumerate(batch)
]
mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]
coarse = [
x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
for i, x in enumerate(batch)
]
coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]
mels = np.stack(mels).astype(np.float32)
if self.mode in ["gauss", "mold"]:
@ -112,8 +106,7 @@ class WaveRNNDataset(Dataset):
elif isinstance(self.mode, int):
coarse = np.stack(coarse).astype(np.int64)
coarse = torch.LongTensor(coarse)
x_input = (2 * coarse[:, : self.seq_len].float() /
(2 ** self.mode - 1.0) - 1.0)
x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
y_coarse = coarse[:, 1:]
mels = torch.FloatTensor(mels)
return x_input, mels, y_coarse
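
Likewise for WaveRNNDataset.collate above, the window arithmetic with concrete numbers (not part of the commit; the values are illustrative):

seq_len, hop_len, pad = 1280, 256, 2

mel_win = seq_len // hop_len + 2 * pad           # 9 mel frames per training window
mel_frames = 40                                  # frames available in one item

max_offset = mel_frames - (mel_win + 2 * pad)    # latest frame we may start from
mel_offset = 7                                   # e.g. a random draw in [0, max_offset)
sig_offset = (mel_offset + pad) * hop_len        # matching sample offset, pad excluded

print(mel_win, max_offset, sig_offset)           # 9 27 2304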

View File

@ -6,18 +6,21 @@ from torch.nn import functional as F
class TorchSTFT(nn.Module): # pylint: disable=abstract-method
"""TODO: Merge this with audio.py"""
def __init__(self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window='hann_window',
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False):
super(TorchSTFT, self).__init__()
def __init__(
self,
n_fft,
hop_length,
win_length,
pad_wav=False,
window="hann_window",
sample_rate=None,
mel_fmin=0,
mel_fmax=None,
n_mels=80,
use_mel=False,
):
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
@ -27,8 +30,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
self.mel_fmax = mel_fmax
self.n_mels = n_mels
self.use_mel = use_mel
self.window = nn.Parameter(getattr(torch, window)(win_length),
requires_grad=False)
self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
self.mel_basis = None
if use_mel:
self._build_mel_basis()
@ -49,7 +51,7 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
x = x.unsqueeze(1)
if self.pad_wav:
padding = int((self.n_fft - self.hop_length) / 2)
x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
# B x D x T x 2
o = torch.stft(
x.squeeze(1),
@ -61,35 +63,34 @@ class TorchSTFT(nn.Module): # pylint: disable=abstract-method
pad_mode="reflect", # compatible with audio.py
normalized=False,
onesided=True,
return_complex=False)
return_complex=False,
)
M = o[:, :, :, 0]
P = o[:, :, :, 1]
S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
if self.use_mel:
S = torch.matmul(self.mel_basis.to(x), S)
return S
def _build_mel_basis(self):
mel_basis = librosa.filters.mel(self.sample_rate,
self.n_fft,
n_mels=self.n_mels,
fmin=self.mel_fmin,
fmax=self.mel_fmax)
mel_basis = librosa.filters.mel(
self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
)
self.mel_basis = torch.from_numpy(mel_basis).float()
#################################
# GENERATOR LOSSES
#################################
class STFTLoss(nn.Module):
""" STFT loss. Input generate and real waveforms are converted
"""STFT loss. Input generate and real waveforms are converted
to spectrograms compared with L1 and Spectral convergence losses.
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self, n_fft, hop_length, win_length):
super(STFTLoss, self).__init__()
super().__init__()
self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
@ -104,15 +105,14 @@ class STFTLoss(nn.Module):
loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
return loss_mag, loss_sc
class MultiScaleSTFTLoss(torch.nn.Module):
""" Multi-scale STFT loss. Input generate and real waveforms are converted
"""Multi-scale STFT loss. Input generate and real waveforms are converted
to spectrograms compared with L1 and Spectral convergence losses.
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
def __init__(self,
n_ffts=(1024, 2048, 512),
hop_lengths=(120, 240, 50),
win_lengths=(600, 1200, 240)):
super(MultiScaleSTFTLoss, self).__init__()
def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
super().__init__()
self.loss_funcs = torch.nn.ModuleList()
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
@ -129,19 +129,25 @@ class MultiScaleSTFTLoss(torch.nn.Module):
loss_mag /= N
return loss_mag, loss_sc
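A hedged usage sketch with the classes above in scope; the forward signature (y_hat, y) is assumed from how GeneratorLoss calls self.stft_loss further down.

import torch

criterion = MultiScaleSTFTLoss()         # default resolutions (1024, 2048, 512)
y = torch.randn(2, 16384)                # real waveforms, B x T
y_hat = torch.randn(2, 16384)            # generated waveforms, B x T
loss_mag, loss_sc = criterion(y_hat, y)  # scalars, averaged over the three resolutions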
class L1SpecLoss(nn.Module):
""" L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""
def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True):
def __init__(
self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
):
super().__init__()
self.use_mel = use_mel
self.stft = TorchSTFT(n_fft,
hop_length,
win_length,
sample_rate=sample_rate,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax,
n_mels=n_mels,
use_mel=use_mel)
self.stft = TorchSTFT(
n_fft,
hop_length,
win_length,
sample_rate=sample_rate,
mel_fmin=mel_fmin,
mel_fmax=mel_fmax,
n_mels=n_mels,
use_mel=use_mel,
)
def forward(self, y_hat, y):
y_hat_M = self.stft(y_hat)
@ -150,9 +156,11 @@ class L1SpecLoss(nn.Module):
loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
return loss_mag
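An illustrative (not canonical) l1_spec_loss_params block as consumed by L1SpecLoss(**C.l1_spec_loss_params) in GeneratorLoss below; the key names mirror the constructor above, while all values are assumptions.

l1_spec_loss_params = {
    "use_mel": True,
    "sample_rate": 22050,   # must equal C.audio["sample_rate"] (asserted in GeneratorLoss)
    "n_fft": 1024,
    "hop_length": 256,
    "win_length": 1024,
    "n_mels": 80,
    "mel_fmin": 0.0,
    "mel_fmax": None,
}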
class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
""" Multiscale STFT loss for multi band model outputs.
"""Multiscale STFT loss for multi band model outputs.
From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""
# pylint: disable=no-self-use
def forward(self, y_hat, y):
y_hat = y_hat.view(-1, 1, y_hat.shape[2])
@ -162,6 +170,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
class MSEGLoss(nn.Module):
""" Mean Squared Generator Loss """
# pylint: disable=no-self-use
def forward(self, score_real):
loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
@ -170,10 +179,11 @@ class MSEGLoss(nn.Module):
class HingeGLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_real):
# TODO: this might be wrong
loss_fake = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1.0 - score_real))
return loss_fake
@ -184,8 +194,11 @@ class HingeGLoss(nn.Module):
class MSEDLoss(nn.Module):
""" Mean Squared Discriminator Loss """
def __init__(self,):
super(MSEDLoss, self).__init__()
def __init__(
self,
):
super().__init__()
self.loss_func = nn.MSELoss()
# pylint: disable=no-self-use
@ -198,17 +211,20 @@ class MSEDLoss(nn.Module):
class HingeDLoss(nn.Module):
""" Hinge Discriminator Loss """
# pylint: disable=no-self-use
def forward(self, score_fake, score_real):
loss_real = torch.mean(F.relu(1. - score_real))
loss_fake = torch.mean(F.relu(1. + score_fake))
loss_real = torch.mean(F.relu(1.0 - score_real))
loss_fake = torch.mean(F.relu(1.0 + score_fake))
loss_d = loss_real + loss_fake
return loss_d, loss_real, loss_fake
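A toy numeric check (not from the diff) of the hinge terms above: real scores are pushed above +1 and fake scores below -1, so only margin violations contribute to the loss.

import torch
import torch.nn.functional as F

score_real = torch.tensor([1.5, 0.2])              # one confident, one weak real score
score_fake = torch.tensor([-2.0, 0.5])             # one confident, one weak fake score
loss_real = torch.mean(F.relu(1.0 - score_real))   # relu([-0.5, 0.8]) -> 0.4
loss_fake = torch.mean(F.relu(1.0 + score_fake))   # relu([-1.0, 1.5]) -> 0.75
loss_d = loss_real + loss_fake                     # 1.15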
class MelganFeatureLoss(nn.Module):
def __init__(self,):
super(MelganFeatureLoss, self).__init__()
def __init__(
self,
):
super().__init__()
self.loss_func = nn.L1Loss()
# pylint: disable=no-self-use
@ -229,8 +245,8 @@ class MelganFeatureLoss(nn.Module):
def _apply_G_adv_loss(scores_fake, loss_func):
""" Compute G adversarial loss function
and normalize values """
"""Compute G adversarial loss function
and normalize values"""
adv_loss = 0
if isinstance(scores_fake, list):
for score_fake in scores_fake:
@ -279,24 +295,26 @@ class GeneratorLoss(nn.Module):
Args:
C (AttrDict): model configuration.
"""
def __init__(self, C):
super().__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
assert not (
C.use_mse_gan_loss and C.use_hinge_gan_loss
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False
self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False
self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False
self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False
self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False
self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False
self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False
self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0
self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0
self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0
self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinde_G_loss_weight' in C else 0.0
self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0
self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0
self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinge_G_loss_weight" in C else 0.0
self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0
if C.use_stft_loss:
self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
@ -309,63 +327,67 @@ class GeneratorLoss(nn.Module):
if C.use_feat_match_loss:
self.feat_match_loss = MelganFeatureLoss()
if C.use_l1_spec_loss:
assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate']
assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)
def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
def forward(
self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
):
gen_loss = 0
adv_loss = 0
return_dict = {}
# STFT Loss
if self.use_stft_loss:
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
return_dict['G_stft_loss_mg'] = stft_loss_mg
return_dict['G_stft_loss_sc'] = stft_loss_sc
stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
return_dict["G_stft_loss_mg"] = stft_loss_mg
return_dict["G_stft_loss_sc"] = stft_loss_sc
gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)
# L1 Spec loss
if self.use_l1_spec_loss:
l1_spec_loss = self.l1_spec_loss(y_hat, y)
return_dict['G_l1_spec_loss'] = l1_spec_loss
return_dict["G_l1_spec_loss"] = l1_spec_loss
gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss
# subband STFT Loss
if self.use_subband_stft_loss:
subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)
# multiscale MSE adversarial loss
if self.use_mse_gan_loss and scores_fake is not None:
mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
return_dict['G_mse_fake_loss'] = mse_fake_loss
return_dict["G_mse_fake_loss"] = mse_fake_loss
adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss
# multiscale Hinge adversarial loss
        if self.use_hinge_gan_loss and scores_fake is not None:
hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
return_dict['G_hinge_fake_loss'] = hinge_fake_loss
return_dict["G_hinge_fake_loss"] = hinge_fake_loss
adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss
# Feature Matching Loss
        if self.use_feat_match_loss and feats_fake is not None:
feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
return_dict['G_feat_match_loss'] = feat_match_loss
return_dict["G_feat_match_loss"] = feat_match_loss
adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
return_dict['G_loss'] = gen_loss + adv_loss
return_dict['G_gen_loss'] = gen_loss
return_dict['G_adv_loss'] = adv_loss
return_dict["G_loss"] = gen_loss + adv_loss
return_dict["G_gen_loss"] = gen_loss
return_dict["G_adv_loss"] = adv_loss
return return_dict
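A hedged end-to-end sketch of GeneratorLoss with a minimal config; AttrDict is a stand-in for the project's config object, the flag and weight names follow the checks visible above, and the parts of the constructor hidden by this hunk are assumed to read only the use_* flags shown.

import torch

class AttrDict(dict):                       # minimal config stand-in (assumption)
    __getattr__ = dict.__getitem__

C = AttrDict(
    use_stft_loss=True,
    stft_loss_weight=0.5,
    stft_loss_params={"n_ffts": (1024, 2048, 512), "hop_lengths": (120, 240, 50), "win_lengths": (600, 1200, 240)},
    use_mse_gan_loss=True,
    mse_G_loss_weight=2.5,
    use_subband_stft_loss=False,
    use_hinge_gan_loss=False,
    use_feat_match_loss=False,
    use_l1_spec_loss=False,
)
criterion = GeneratorLoss(C)
y = torch.randn(2, 1, 8192)                 # real waveforms, B x 1 x T
y_hat = torch.randn(2, 1, 8192)             # generated waveforms
scores_fake = [torch.randn(2, 1, 32)]       # one score map per discriminator
losses = criterion(y_hat=y_hat, y=y, scores_fake=scores_fake)
print(losses["G_loss"], losses["G_stft_loss_mg"], losses["G_mse_fake_loss"])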
class DiscriminatorLoss(nn.Module):
"""Like ```GeneratorLoss```"""
def __init__(self, C):
super().__init__()
assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
" [!] Cannot use HingeGANLoss and MSEGANLoss together."
assert not (
C.use_mse_gan_loss and C.use_hinge_gan_loss
), " [!] Cannot use HingeGANLoss and MSEGANLoss together."
self.use_mse_gan_loss = C.use_mse_gan_loss
self.use_hinge_gan_loss = C.use_hinge_gan_loss
@ -381,23 +403,21 @@ class DiscriminatorLoss(nn.Module):
if self.use_mse_gan_loss:
mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake,
scores_real=scores_real,
loss_func=self.mse_loss)
return_dict['D_mse_gan_loss'] = mse_D_loss
return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
)
return_dict["D_mse_gan_loss"] = mse_D_loss
return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
loss += mse_D_loss
if self.use_hinge_gan_loss:
hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
scores_fake=scores_fake,
scores_real=scores_real,
loss_func=self.hinge_loss)
return_dict['D_hinge_gan_loss'] = hinge_D_loss
return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
)
return_dict["D_hinge_gan_loss"] = hinge_D_loss
return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
loss += hinge_D_loss
return_dict['D_loss'] = loss
return_dict["D_loss"] = loss
return return_dict
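A matching hedged sketch for DiscriminatorLoss; the forward(scores_fake, scores_real) signature is assumed from the body above, and the minimal config only needs attribute access here.

import torch
from types import SimpleNamespace

C_d = SimpleNamespace(use_mse_gan_loss=True, use_hinge_gan_loss=False)  # minimal config (assumption)
d_criterion = DiscriminatorLoss(C_d)
scores_fake = [torch.randn(2, 1, 32)]
scores_real = [torch.randn(2, 1, 32)]
d_losses = d_criterion(scores_fake, scores_real)
print(d_losses["D_loss"], d_losses["D_mse_gan_loss"])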
View File
@ -4,7 +4,7 @@ from torch.nn.utils import weight_norm
class ResidualStack(nn.Module):
def __init__(self, channels, num_res_blocks, kernel_size):
super(ResidualStack, self).__init__()
super().__init__()
assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
base_padding = (kernel_size - 1) // 2
@ -12,26 +12,23 @@ class ResidualStack(nn.Module):
self.blocks = nn.ModuleList()
for idx in range(num_res_blocks):
layer_kernel_size = kernel_size
layer_dilation = layer_kernel_size**idx
layer_dilation = layer_kernel_size ** idx
layer_padding = base_padding * layer_dilation
self.blocks += [nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(layer_padding),
weight_norm(
nn.Conv1d(channels,
channels,
kernel_size=kernel_size,
dilation=layer_dilation,
bias=True)),
nn.LeakyReLU(0.2),
weight_norm(
nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
)]
self.blocks += [
nn.Sequential(
nn.LeakyReLU(0.2),
nn.ReflectionPad1d(layer_padding),
weight_norm(
nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
),
nn.LeakyReLU(0.2),
weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
)
]
self.shortcuts = nn.ModuleList([
weight_norm(nn.Conv1d(channels, channels, kernel_size=1,
bias=True)) for i in range(num_res_blocks)
])
self.shortcuts = nn.ModuleList(
[weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)]
)
def forward(self, x):
for block, shortcut in zip(self.blocks, self.shortcuts):
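A worked example (not part of the diff) of the dilation and padding schedule built above, for kernel_size=3 and three residual blocks; each block keeps the time length thanks to the matching ReflectionPad1d.

kernel_size = 3
base_padding = (kernel_size - 1) // 2          # 1
for idx in range(3):
    layer_dilation = kernel_size ** idx        # 1, 3, 9
    layer_padding = base_padding * layer_dilation
    print(idx, layer_dilation, layer_padding)  # (0, 1, 1) (1, 3, 3) (2, 9, 9)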
View File
@ -4,54 +4,44 @@ from torch.nn import functional as F
class ResidualBlock(torch.nn.Module):
"""Residual block module in WaveNet."""
def __init__(self,
kernel_size=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
dilation=1,
bias=True,
use_causal_conv=False):
super(ResidualBlock, self).__init__()
def __init__(
self,
kernel_size=3,
res_channels=64,
gate_channels=128,
skip_channels=64,
aux_channels=80,
dropout=0.0,
dilation=1,
bias=True,
use_causal_conv=False,
):
super().__init__()
self.dropout = dropout
# no future time stamps available
if use_causal_conv:
padding = (kernel_size - 1) * dilation
else:
assert (kernel_size -
1) % 2 == 0, "Not support even number kernel size."
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
padding = (kernel_size - 1) // 2 * dilation
self.use_causal_conv = use_causal_conv
# dilation conv
self.conv = torch.nn.Conv1d(res_channels,
gate_channels,
kernel_size,
padding=padding,
dilation=dilation,
bias=bias)
self.conv = torch.nn.Conv1d(
res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
)
# local conditioning
if aux_channels > 0:
self.conv1x1_aux = torch.nn.Conv1d(aux_channels,
gate_channels,
1,
bias=False)
self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
else:
self.conv1x1_aux = None
# conv output is split into two groups
gate_out_channels = gate_channels // 2
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels,
res_channels,
1,
bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels,
skip_channels,
1,
bias=bias)
self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)
def forward(self, x, c):
"""
@ -63,7 +53,7 @@ class ResidualBlock(torch.nn.Module):
x = self.conv(x)
        # remove future time steps if causal convolution is used
x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x
x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x
        # split into two parts for gated activation
splitdim = 1
@ -82,6 +72,6 @@ class ResidualBlock(torch.nn.Module):
s = self.conv1x1_skip(x)
# for residual connection
x = (self.conv1x1_out(x) + residual) * (0.5**2)
x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)
return x, s
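A hedged shape check for the WaveNet residual block above (class assumed in scope); x is the residual signal and c the local conditioning features, e.g. upsampled mel frames.

import torch

block = ResidualBlock(kernel_size=3, res_channels=64, gate_channels=128,
                      skip_channels=64, aux_channels=80, dilation=2)
x = torch.randn(1, 64, 1000)     # B x res_channels x T
c = torch.randn(1, 80, 1000)     # B x aux_channels x T
x_out, skip = block(x, c)
print(x_out.shape, skip.shape)   # (1, 64, 1000) and (1, 64, 1000)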
View File
@ -9,21 +9,21 @@ from scipy import signal as sig
# https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
class PQMF(torch.nn.Module):
def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
super(PQMF, self).__init__()
super().__init__()
self.N = N
self.taps = taps
self.cutoff = cutoff
self.beta = beta
QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
H = np.zeros((N, len(QMF)))
G = np.zeros((N, len(QMF)))
for k in range(N):
constant_factor = (2 * k + 1) * (np.pi /
(2 * N)) * (np.arange(taps + 1) -
((taps - 1) / 2)) # TODO: (taps - 1) -> taps
phase = (-1)**k * np.pi / 4
constant_factor = (
(2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
) # TODO: (taps - 1) -> taps
phase = (-1) ** k * np.pi / 4
H[k] = 2 * QMF * np.cos(constant_factor + phase)
G[k] = 2 * QMF * np.cos(constant_factor - phase)
@ -49,8 +49,6 @@ class PQMF(torch.nn.Module):
return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)
def synthesis(self, x):
x = F.conv_transpose1d(x,
self.updown_filter * self.N,
stride=self.N)
x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
x = F.conv1d(x, self.G, padding=self.taps // 2)
return x
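A hedged round-trip sketch for the PQMF bank above; the band-splitting method is assumed to be named analysis (its definition is collapsed in this hunk), mirroring the synthesis method shown.

import torch

pqmf = PQMF(N=4)
wav = torch.randn(1, 1, 16384)       # B x 1 x T
subbands = pqmf.analysis(wav)        # B x N x T // N
wav_hat = pqmf.synthesis(subbands)   # B x 1 x ~T (up to filter delay)
print(subbands.shape, wav_hat.shape)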
Some files were not shown because too many files have changed in this diff.