mirror of https://github.com/coqui-ai/TTS.git
format with black and pylint 2.7.3
commit 0e79fa86ad (parent 5de7eb708b)
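The hunks below are the output of the formatting pass named above. The mirror does not record the exact invocation; a plausible reconstruction, judging from the roughly 120-character lines black produced here, is the following (the flags are assumptions, only the tool names and the pylint version come from the commit message):

Example run:
    pip install black pylint==2.7.3
    black --line-length 120 TTS/
    pylint TTS/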
@@ -15,16 +15,14 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.io import load_config


if __name__ == '__main__':
# pylint: disable=bad-continuation
if __name__ == "__main__":
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description='''Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n'''

'''Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n'''

'''
description="""Extract attention masks from trained Tacotron/Tacotron2 models.
These masks can be used for different purposes including training a TTS model with a Duration Predictor.\n\n"""
"""Each attention mask is written to the same path as the input wav file with ".npy" file extension.
(e.g. path/bla.wav (wav file) --> path/bla.npy (attention mask))\n"""
"""
Example run:
CUDA_VISIBLE_DEVICE="0" python TTS/bin/compute_attention_masks.py
--model_path /data/rw/home/Models/ljspeech-dcattn-December-14-2020_11+10AM-9d0e8c7/checkpoint_200000.pth.tar
@@ -34,53 +32,44 @@ Example run:
--batch_size 32
--dataset ljspeech
--use_cuda True
''',
formatter_class=RawTextHelpFormatter
)
parser.add_argument('--model_path',
type=str,
required=True,
help='Path to Tacotron/Tacotron2 model file ')
parser.add_argument(
'--config_path',
type=str,
required=True,
help='Path to Tacotron/Tacotron2 config file.',
""",
formatter_class=RawTextHelpFormatter,
)
parser.add_argument('--dataset',
type=str,
default='',
required=True,
help='Target dataset processor name from TTS.tts.dataset.preprocess.')

parser.add_argument("--model_path", type=str, required=True, help="Path to Tacotron/Tacotron2 model file ")
parser.add_argument(
'--dataset_metafile',
"--config_path",
type=str,
default='',
required=True,
help='Dataset metafile inclusing file paths with transcripts.')
help="Path to Tacotron/Tacotron2 config file.",
)
parser.add_argument(
'--data_path',
"--dataset",
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--use_cuda',
type=bool,
default=False,
help="enable/disable cuda.")
default="",
required=True,
help="Target dataset processor name from TTS.tts.dataset.preprocess.",
)

parser.add_argument(
'--batch_size',
default=16,
type=int,
help='Batch size for the model. Use batch_size=1 if you have no CUDA.')
"--dataset_metafile",
type=str,
default="",
required=True,
help="Dataset metafile inclusing file paths with transcripts.",
)
parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
parser.add_argument("--use_cuda", type=bool, default=False, help="enable/disable cuda.")

parser.add_argument(
"--batch_size", default=16, type=int, help="Batch size for the model. Use batch_size=1 if you have no CUDA."
)
args = parser.parse_args()

C = load_config(args.config_path)
ap = AudioProcessor(**C.audio)

# if the vocabulary was passed, replace the default
if 'characters' in C.keys():
if "characters" in C.keys():
symbols, phonemes = make_symbols(**C.characters)

# load the model
@@ -91,28 +80,32 @@ Example run:
model.eval()

# data loader
preprocessor = importlib.import_module('TTS.tts.datasets.preprocess')
preprocessor = importlib.import_module("TTS.tts.datasets.preprocess")
preprocessor = getattr(preprocessor, args.dataset)
meta_data = preprocessor(args.data_path, args.dataset_metafile)
dataset = MyDataset(model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,
ap=ap,
meta_data=meta_data,
tp=C.characters if 'characters' in C.keys() else None,
add_blank=C['add_blank'] if 'add_blank' in C.keys() else False,
use_phonemes=C.use_phonemes,
phoneme_cache_path=C.phoneme_cache_path,
phoneme_language=C.phoneme_language,
enable_eos_bos=C.enable_eos_bos_chars)
dataset = MyDataset(
model.decoder.r,
C.text_cleaner,
compute_linear_spec=False,
ap=ap,
meta_data=meta_data,
tp=C.characters if "characters" in C.keys() else None,
add_blank=C["add_blank"] if "add_blank" in C.keys() else False,
use_phonemes=C.use_phonemes,
phoneme_cache_path=C.phoneme_cache_path,
phoneme_language=C.phoneme_language,
enable_eos_bos=C.enable_eos_bos_chars,
)

dataset.sort_items()
loader = DataLoader(dataset,
batch_size=args.batch_size,
num_workers=4,
collate_fn=dataset.collate_fn,
shuffle=False,
drop_last=False)
loader = DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=4,
collate_fn=dataset.collate_fn,
shuffle=False,
drop_last=False,
)

# compute attentions
file_paths = []
@@ -134,25 +127,29 @@ Example run:
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()

mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(
text_input, text_lengths, mel_input)
mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)

alignments = alignments.detach()
for idx, alignment in enumerate(alignments):
item_idx = item_idxs[idx]
# interpolate if r > 1
alignment = torch.nn.functional.interpolate(
alignment.transpose(0, 1).unsqueeze(0),
size=None,
scale_factor=model.decoder.r,
mode='nearest',
align_corners=None,
recompute_scale_factor=None).squeeze(0).transpose(0, 1)
alignment = (
torch.nn.functional.interpolate(
alignment.transpose(0, 1).unsqueeze(0),
size=None,
scale_factor=model.decoder.r,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
)
.squeeze(0)
.transpose(0, 1)
)
# remove paddings
alignment = alignment[:mel_lengths[idx], :text_lengths[idx]].cpu().numpy()
alignment = alignment[: mel_lengths[idx], : text_lengths[idx]].cpu().numpy()
# set file paths
wav_file_name = os.path.basename(item_idx)
align_file_name = os.path.splitext(wav_file_name)[0] + '.npy'
align_file_name = os.path.splitext(wav_file_name)[0] + ".npy"
file_path = item_idx.replace(wav_file_name, align_file_name)
# save output
file_paths.append([item_idx, file_path])
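As an aside on the interpolate call reformatted above: each decoder step of Tacotron emits r frames, so the attention map has to be stretched along its decoder-time axis by model.decoder.r before one row per mel frame can be saved. A minimal standalone sketch of that expansion (the tensor sizes are made up for illustration; only torch is required):

    import torch

    r = 2  # hypothetical reduction factor
    alignment = torch.rand(50, 20)  # [decoder steps, input characters]
    expanded = (
        torch.nn.functional.interpolate(
            alignment.transpose(0, 1).unsqueeze(0),  # -> [1, 20, 50]
            scale_factor=r,
            mode="nearest",
        )
        .squeeze(0)
        .transpose(0, 1)
    )
    print(expanded.shape)  # torch.Size([100, 20]): one attention row per frame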
@@ -13,91 +13,72 @@ from TTS.tts.utils.speakers import save_speaker_mapping
from TTS.tts.datasets.preprocess import load_meta_data

parser = argparse.ArgumentParser(
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.')
parser.add_argument(
'model_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).')
parser.add_argument(
'config_path',
type=str,
help='Path to config file for training.',
description='Compute embedding vectors for each wav file in a dataset. If "target_dataset" is defined, it generates "speakers.json" necessary for training a multi-speaker model.'
)
parser.add_argument("model_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).")
parser.add_argument(
'data_path',
"config_path",
type=str,
help='Data path for wav files - directory or CSV file')
help="Path to config file for training.",
)
parser.add_argument("data_path", type=str, help="Data path for wav files - directory or CSV file")
parser.add_argument("output_path", type=str, help="path for training outputs.")
parser.add_argument(
'output_path',
"--target_dataset",
type=str,
help='path for training outputs.')
parser.add_argument(
'--target_dataset',
type=str,
default='',
help='Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.'
)
parser.add_argument(
'--use_cuda', type=bool, help='flag to set cuda.', default=False
)
parser.add_argument(
'--separator', type=str, help='Separator used in file if CSV is passed for data_path', default='|'
default="",
help="Target dataset to pick a processor from TTS.tts.dataset.preprocess. Necessary to create a speakers.json file.",
)
parser.add_argument("--use_cuda", type=bool, help="flag to set cuda.", default=False)
parser.add_argument("--separator", type=str, help="Separator used in file if CSV is passed for data_path", default="|")
args = parser.parse_args()


c = load_config(args.config_path)
ap = AudioProcessor(**c['audio'])
ap = AudioProcessor(**c["audio"])

data_path = args.data_path
split_ext = os.path.splitext(data_path)
sep = args.separator

if args.target_dataset != '':
if args.target_dataset != "":
# if target dataset is defined
dataset_config = [
{
"name": args.target_dataset,
"path": args.data_path,
"meta_file_train": None,
"meta_file_val": None
},
{"name": args.target_dataset, "path": args.data_path, "meta_file_train": None, "meta_file_val": None},
]
wav_files, _ = load_meta_data(dataset_config, eval_split=False)
output_files = [wav_file[1].replace(data_path, args.output_path).replace(
'.wav', '.npy') for wav_file in wav_files]
output_files = [wav_file[1].replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files]
else:
# if target dataset is not defined
if len(split_ext) > 0 and split_ext[1].lower() == '.csv':
if len(split_ext) > 0 and split_ext[1].lower() == ".csv":
# Parse CSV
print(f'CSV file: {data_path}')
print(f"CSV file: {data_path}")
with open(data_path) as f:
wav_path = os.path.join(os.path.dirname(data_path), 'wavs')
wav_path = os.path.join(os.path.dirname(data_path), "wavs")
wav_files = []
print(f'Separator is: {sep}')
print(f"Separator is: {sep}")
for line in f:
components = line.split(sep)
if len(components) != 2:
print("Invalid line")
continue
wav_file = os.path.join(wav_path, components[0] + '.wav')
#print(f'wav_file: {wav_file}')
wav_file = os.path.join(wav_path, components[0] + ".wav")
# print(f'wav_file: {wav_file}')
if os.path.exists(wav_file):
wav_files.append(wav_file)
print(f'Count of wavs imported: {len(wav_files)}')
print(f"Count of wavs imported: {len(wav_files)}")
else:
# Parse all wav files in data_path
wav_files = glob.glob(data_path + '/**/*.wav', recursive=True)
wav_files = glob.glob(data_path + "/**/*.wav", recursive=True)

output_files = [wav_file.replace(data_path, args.output_path).replace(
'.wav', '.npy') for wav_file in wav_files]
output_files = [wav_file.replace(data_path, args.output_path).replace(".wav", ".npy") for wav_file in wav_files]

for output_file in output_files:
os.makedirs(os.path.dirname(output_file), exist_ok=True)

# define Encoder model
model = SpeakerEncoder(**c.model)
model.load_state_dict(torch.load(args.model_path)['model'])
model.load_state_dict(torch.load(args.model_path)["model"])
model.eval()
if args.use_cuda:
model.cuda()
@@ -117,14 +98,14 @@ for idx, wav_file in enumerate(tqdm(wav_files)):
embedd = embedd.detach().cpu().numpy()
np.save(output_files[idx], embedd)

if args.target_dataset != '':
if args.target_dataset != "":
# create speaker_mapping if target dataset is defined
wav_file_name = os.path.basename(wav_file)
speaker_mapping[wav_file_name] = {}
speaker_mapping[wav_file_name]['name'] = speaker_name
speaker_mapping[wav_file_name]['embedding'] = embedd.flatten().tolist()
speaker_mapping[wav_file_name]["name"] = speaker_name
speaker_mapping[wav_file_name]["embedding"] = embedd.flatten().tolist()

if args.target_dataset != '':
if args.target_dataset != "":
# save speaker_mapping if target dataset is defined
mapping_file_path = os.path.join(args.output_path, 'speakers.json')
mapping_file_path = os.path.join(args.output_path, "speakers.json")
save_speaker_mapping(args.output_path, speaker_mapping)
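One behavioural caveat that this commit leaves untouched in both scripts above: argparse's type=bool does not parse strings, it simply calls bool() on them, so any non-empty value, including the literal string "False", enables the flag. A short demonstration:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use_cuda", type=bool, default=False)

    print(parser.parse_args([]).use_cuda)                       # False
    print(parser.parse_args(["--use_cuda", "False"]).use_cuda)  # True, not False

synthesize.py further down defines a str2bool converter precisely to avoid this.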
@@ -15,25 +15,24 @@ from TTS.utils.audio import AudioProcessor

def main():
"""Run preprocessing process."""
parser = argparse.ArgumentParser(
description="Compute mean and variance of spectrogtram features.")
parser.add_argument("--config_path", type=str, required=True,
help="TTS config file path to define audio processin parameters.")
parser.add_argument("--out_path", type=str, required=True,
help="save path (directory and filename).")
parser = argparse.ArgumentParser(description="Compute mean and variance of spectrogtram features.")
parser.add_argument(
"--config_path", type=str, required=True, help="TTS config file path to define audio processin parameters."
)
parser.add_argument("--out_path", type=str, required=True, help="save path (directory and filename).")
args = parser.parse_args()

# load config
CONFIG = load_config(args.config_path)
CONFIG.audio['signal_norm'] = False # do not apply earlier normalization
CONFIG.audio['stats_path'] = None # discard pre-defined stats
CONFIG.audio["signal_norm"] = False # do not apply earlier normalization
CONFIG.audio["stats_path"] = None # discard pre-defined stats

# load audio processor
ap = AudioProcessor(**CONFIG.audio)

# load the meta data of target dataset
if 'data_path' in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, '**', '*.wav'), recursive=True)
if "data_path" in CONFIG.keys():
dataset_items = glob.glob(os.path.join(CONFIG.data_path, "**", "*.wav"), recursive=True)
else:
dataset_items = load_meta_data(CONFIG.datasets)[0] # take only train data
print(f" > There are {len(dataset_items)} files.")
@@ -63,27 +62,27 @@ def main():

output_file_path = args.out_path
stats = {}
stats['mel_mean'] = mel_mean
stats['mel_std'] = mel_scale
stats['linear_mean'] = linear_mean
stats['linear_std'] = linear_scale
stats["mel_mean"] = mel_mean
stats["mel_std"] = mel_scale
stats["linear_mean"] = linear_mean
stats["linear_std"] = linear_scale

print(f' > Avg mel spec mean: {mel_mean.mean()}')
print(f' > Avg mel spec scale: {mel_scale.mean()}')
print(f' > Avg linear spec mean: {linear_mean.mean()}')
print(f' > Avg lienar spec scale: {linear_scale.mean()}')
print(f" > Avg mel spec mean: {mel_mean.mean()}")
print(f" > Avg mel spec scale: {mel_scale.mean()}")
print(f" > Avg linear spec mean: {linear_mean.mean()}")
print(f" > Avg lienar spec scale: {linear_scale.mean()}")

# set default config values for mean-var scaling
CONFIG.audio['stats_path'] = output_file_path
CONFIG.audio['signal_norm'] = True
CONFIG.audio["stats_path"] = output_file_path
CONFIG.audio["signal_norm"] = True
# remove redundant values
del CONFIG.audio['max_norm']
del CONFIG.audio['min_level_db']
del CONFIG.audio['symmetric_norm']
del CONFIG.audio['clip_norm']
stats['audio_config'] = CONFIG.audio
del CONFIG.audio["max_norm"]
del CONFIG.audio["min_level_db"]
del CONFIG.audio["symmetric_norm"]
del CONFIG.audio["clip_norm"]
stats["audio_config"] = CONFIG.audio
np.save(output_file_path, stats, allow_pickle=True)
print(f' > stats saved to {output_file_path}')
print(f" > stats saved to {output_file_path}")


if __name__ == "__main__":
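For context on what the saved stats are used for: with signal_norm enabled and stats_path set, the audio processor applies per-bin mean-variance scaling to spectrograms. A minimal sketch of that normalization, not the actual AudioProcessor API (the file name and shapes are illustrative assumptions):

    import numpy as np

    stats = np.load("scale_stats.npy", allow_pickle=True).item()  # assumed file name
    mel_mean, mel_std = stats["mel_mean"], stats["mel_std"]  # each of shape [num_mels]

    def normalize(mel):  # mel: [num_mels, T]
        return (mel - mel_mean[:, None]) / mel_std[:, None]

    def denormalize(mel_norm):
        return mel_norm * mel_std[:, None] + mel_mean[:, None]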
@@ -9,15 +9,9 @@ from TTS.vocoder.tf.utils.tflite import convert_melgan_to_tflite


parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to tflite output binary.')
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()

# Set constants

@@ -8,27 +8,22 @@ import torch

from TTS.utils.io import load_config
from TTS.vocoder.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.vocoder.tf.utils.generic_utils import \
setup_generator as setup_tf_generator
compare_torch_tf,
convert_tf_name,
transfer_weights_torch_to_tf,
)
from TTS.vocoder.tf.utils.generic_utils import setup_generator as setup_tf_generator
from TTS.vocoder.tf.utils.io import save_checkpoint
from TTS.vocoder.utils.generic_utils import setup_generator

# prevent GPU use
os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# define args
parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument(
'--output_path',
type=str,
help='path to output file including file name to save TF model.')
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()

# load model config
@@ -38,9 +33,8 @@ num_speakers = 0

# init torch model
model = setup_generator(c)
checkpoint = torch.load(args.torch_model_path,
map_location=torch.device('cpu'))
state_dict = checkpoint['model']
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)
model.remove_weight_norm()
state_dict = model.state_dict()
@@ -48,7 +42,7 @@ state_dict = model.state_dict()
# init tf model
model_tf = setup_tf_generator(c)

common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
# get tf_model graph by passing an input
# B x D x T
dummy_input = tf.random.uniform((7, 80, 64), dtype=tf.float32)
@@ -66,10 +60,7 @@ for tf_name in tf_var_names:
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names
]
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
@@ -107,10 +98,8 @@ model.inference_padding = 0
model_tf.inference_padding = 0
output_torch = model.inference(dummy_input_torch)
output_tf = model_tf(dummy_input_tf, training=False)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(
output_torch, output_tf)
assert compare_torch_tf(output_torch, output_tf) < 1e-5, compare_torch_tf(output_torch, output_tf)

# save tf model
save_checkpoint(model_tf, checkpoint['step'], checkpoint['epoch'],
args.output_path)
print(' > Model conversion is successfully completed :).')
save_checkpoint(model_tf, checkpoint["step"], checkpoint["epoch"], args.output_path)
print(" > Model conversion is successfully completed :).")
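The name-matching heuristic that black collapsed into one line above is worth a second look: every TensorFlow variable not covered by var_map is paired with the most string-similar torch parameter name. A reduced, runnable sketch of that step with made-up names:

    from difflib import SequenceMatcher

    import numpy as np

    torch_var_names = ["encoder.lstm.weight_ih_l0", "decoder.stopnet.1.linear_layer.weight"]
    tf_name_edited = "encoder/lstm/kernel"  # hypothetical name after convert_tf_name()

    ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
    max_idx = np.argmax(ratios)
    print(torch_var_names[max_idx])  # the more similar torch name wins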
@@ -10,15 +10,9 @@ from TTS.tts.tf.utils.tflite import convert_tacotron2_to_tflite


parser = argparse.ArgumentParser()
parser.add_argument('--tf_model',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to tflite output binary.')
parser.add_argument("--tf_model", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to tflite output binary.")
args = parser.parse_args()

# Set constants

@@ -8,27 +8,20 @@ import numpy as np
import tensorflow as tf
import torch
from TTS.tts.tf.models.tacotron2 import Tacotron2
from TTS.tts.tf.utils.convert_torch_to_tf_utils import (
compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf)
from TTS.tts.tf.utils.convert_torch_to_tf_utils import compare_torch_tf, convert_tf_name, transfer_weights_torch_to_tf
from TTS.tts.tf.utils.generic_utils import save_checkpoint
from TTS.tts.utils.generic_utils import setup_model
from TTS.tts.utils.text.symbols import phonemes, symbols
from TTS.utils.io import load_config

sys.path.append('/home/erogol/Projects')
os.environ['CUDA_VISIBLE_DEVICES'] = ''
sys.path.append("/home/erogol/Projects")
os.environ["CUDA_VISIBLE_DEVICES"] = ""


parser = argparse.ArgumentParser()
parser.add_argument('--torch_model_path',
type=str,
help='Path to target torch model to be converted to TF.')
parser.add_argument('--config_path',
type=str,
help='Path to config file of torch model.')
parser.add_argument('--output_path',
type=str,
help='path to output file including file name to save TF model.')
parser.add_argument("--torch_model_path", type=str, help="Path to target torch model to be converted to TF.")
parser.add_argument("--config_path", type=str, help="Path to config file of torch model.")
parser.add_argument("--output_path", type=str, help="path to output file including file name to save TF model.")
args = parser.parse_args()

# load model config
@@ -39,51 +32,48 @@ num_speakers = 0
# init torch model
num_chars = len(phonemes) if c.use_phonemes else len(symbols)
model = setup_model(num_chars, num_speakers, c)
checkpoint = torch.load(args.torch_model_path,
map_location=torch.device('cpu'))
state_dict = checkpoint['model']
checkpoint = torch.load(args.torch_model_path, map_location=torch.device("cpu"))
state_dict = checkpoint["model"]
model.load_state_dict(state_dict)

# init tf model
model_tf = Tacotron2(num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio['num_mels'],
decoder_output_dim=c.audio['num_mels'],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder)
model_tf = Tacotron2(
num_chars=num_chars,
num_speakers=num_speakers,
r=model.decoder.r,
postnet_output_dim=c.audio["num_mels"],
decoder_output_dim=c.audio["num_mels"],
attn_type=c.attention_type,
attn_win=c.windowing,
attn_norm=c.attention_norm,
prenet_type=c.prenet_type,
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
attn_K=c.attention_heads,
separate_stopnet=c.separate_stopnet,
bidirectional_decoder=c.bidirectional_decoder,
)

# set initial layer mapping - these are not captured by the below heuristic approach
# TODO: set layer names so that we can remove these manual matching
common_sufix = '/.ATTRIBUTES/VARIABLE_VALUE'
common_sufix = "/.ATTRIBUTES/VARIABLE_VALUE"
var_map = [
('embedding/embeddings:0', 'embedding.weight'),
('encoder/lstm/forward_lstm/lstm_cell_1/kernel:0',
'encoder.lstm.weight_ih_l0'),
('encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0'),
('encoder/lstm/backward_lstm/lstm_cell_2/kernel:0',
'encoder.lstm.weight_ih_l0_reverse'),
('encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0',
'encoder.lstm.weight_hh_l0_reverse'),
('encoder/lstm/forward_lstm/lstm_cell_1/bias:0',
('encoder.lstm.bias_ih_l0', 'encoder.lstm.bias_hh_l0')),
('encoder/lstm/backward_lstm/lstm_cell_2/bias:0',
('encoder.lstm.bias_ih_l0_reverse', 'encoder.lstm.bias_hh_l0_reverse')),
('attention/v/kernel:0', 'decoder.attention.v.linear_layer.weight'),
('decoder/linear_projection/kernel:0',
'decoder.linear_projection.linear_layer.weight'),
('decoder/stopnet/kernel:0', 'decoder.stopnet.1.linear_layer.weight')
("embedding/embeddings:0", "embedding.weight"),
("encoder/lstm/forward_lstm/lstm_cell_1/kernel:0", "encoder.lstm.weight_ih_l0"),
("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0", "encoder.lstm.weight_hh_l0"),
("encoder/lstm/backward_lstm/lstm_cell_2/kernel:0", "encoder.lstm.weight_ih_l0_reverse"),
("encoder/lstm/backward_lstm/lstm_cell_2/recurrent_kernel:0", "encoder.lstm.weight_hh_l0_reverse"),
("encoder/lstm/forward_lstm/lstm_cell_1/bias:0", ("encoder.lstm.bias_ih_l0", "encoder.lstm.bias_hh_l0")),
(
"encoder/lstm/backward_lstm/lstm_cell_2/bias:0",
("encoder.lstm.bias_ih_l0_reverse", "encoder.lstm.bias_hh_l0_reverse"),
),
("attention/v/kernel:0", "decoder.attention.v.linear_layer.weight"),
("decoder/linear_projection/kernel:0", "decoder.linear_projection.linear_layer.weight"),
("decoder/stopnet/kernel:0", "decoder.stopnet.1.linear_layer.weight"),
]

# %%
@@ -101,10 +91,7 @@ for tf_name in tf_var_names:
if tf_name in [name[0] for name in var_map]:
continue
tf_name_edited = convert_tf_name(tf_name)
ratios = [
SequenceMatcher(None, torch_name, tf_name_edited).ratio()
for torch_name in torch_var_names
]
ratios = [SequenceMatcher(None, torch_name, tf_name_edited).ratio() for torch_name in torch_var_names]
max_idx = np.argmax(ratios)
matching_name = torch_var_names[max_idx]
del torch_var_names[max_idx]
@@ -124,25 +111,21 @@ input_ids = torch.randint(0, 24, (1, 128)).long()

o_t = model.embedding(input_ids)
o_tf = model_tf.embedding(input_ids.detach().numpy())
assert abs(o_t.detach().numpy() -
o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() -
o_tf.numpy()).sum()
assert abs(o_t.detach().numpy() - o_tf.numpy()).sum() < 1e-5, abs(o_t.detach().numpy() - o_tf.numpy()).sum()

# compare encoder outputs
oo_en = model.encoder.inference(o_t.transpose(1, 2))
ooo_en = model_tf.encoder(o_t.detach().numpy(), training=False)
assert compare_torch_tf(oo_en, ooo_en) < 1e-5

#pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin
# compare decoder.attention_rnn
inp = torch.rand([1, 768])
inp_tf = inp.numpy()
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.attention_rnn(inp)
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf,
states[2],
training=False)
output_tf, memory_state = model_tf.decoder.attention_rnn(inp_tf, states[2], training=False)
assert compare_torch_tf(output, output_tf).mean() < 1e-5

query = output
@@ -153,8 +136,7 @@ inputs_tf = inputs.numpy()
# compare decoder.attention
model.decoder.attention.init_states(inputs)
processes_inputs = model.decoder.attention.preprocess_inputs(inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(
query, processes_inputs)
loc_attn, proc_query = model.decoder.attention.get_location_attention(query, processes_inputs)
context = model.decoder.attention(query, inputs, processes_inputs, None)

attention_states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)[-1]
@@ -169,13 +151,10 @@ assert compare_torch_tf(context, context_tf) < 1e-5
# compare decoder.decoder_rnn
input = torch.rand([1, 1536])
input_tf = input.numpy()
model.decoder._init_states(oo_en, mask=None) #pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(
input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
model.decoder._init_states(oo_en, mask=None) # pylint: disable=protected-access
output, cell_state = model.decoder.decoder_rnn(input, [model.decoder.decoder_hidden, model.decoder.decoder_cell])
states = model_tf.decoder.build_decoder_initial_states(1, 512, 128)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf,
states[3],
training=False)
output_tf, memory_state = model_tf.decoder.decoder_rnn(input_tf, states[3], training=False)
assert abs(input - input_tf).mean() < 1e-5
assert compare_torch_tf(output, output_tf).mean() < 1e-5

@@ -198,12 +177,10 @@ assert compare_torch_tf(output.transpose(1, 2), output_tf) < 1e-4
outputs_torch = model.inference(input_ids)
outputs_tf = model_tf(tf.convert_to_tensor(input_ids.numpy()))
print(abs(outputs_torch[0].numpy()[:, 0] - outputs_tf[0].numpy()[:, 0]).mean())
assert compare_torch_tf(outputs_torch[2][:, 50, :],
outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[2][:, 50, :], outputs_tf[2][:, 50, :]) < 1e-5
assert compare_torch_tf(outputs_torch[0], outputs_tf[0]) < 1e-4

# %%
# save tf model
save_checkpoint(model_tf, None, checkpoint['step'], checkpoint['epoch'],
checkpoint['r'], args.output_path)
print(' > Model conversion is successfully completed :).')
save_checkpoint(model_tf, None, checkpoint["step"], checkpoint["epoch"], checkpoint["r"], args.output_path)
print(" > Model conversion is successfully completed :).")
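A note on the one asymmetric entry in var_map above: a PyTorch LSTM keeps two bias vectors (bias_ih and bias_hh) while a Keras LSTM cell keeps one, which is why a single TF bias is mapped to a tuple of two torch names. A faithful transfer has to combine the pair, presumably by summation, which is the standard equivalence; the actual handling lives in transfer_weights_torch_to_tf and is not part of this diff:

    import numpy as np

    units = 512  # illustrative size; an LSTM stacks its four gates, hence 4 * units
    bias_ih = np.random.randn(4 * units)
    bias_hh = np.random.randn(4 * units)
    keras_bias = bias_ih + bias_hh  # single bias consumed by the Keras cell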
@@ -15,26 +15,19 @@ def main():
Call train.py as a new process and pass command arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument("--script", type=str, help="Target training script to distibute.")
parser.add_argument(
'--script',
type=str,
help='Target training script to distibute.')
parser.add_argument(
'--continue_path',
"--continue_path",
type=str,
help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
default='',
required='--config_path' not in sys.argv)
default="",
required="--config_path" not in sys.argv,
)
parser.add_argument(
'--restore_path',
type=str,
help='Model file to be restored. Use to finetune a model.',
default='')
"--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
)
parser.add_argument(
'--config_path',
type=str,
help='Path to config file for training.',
required='--continue_path' not in sys.argv
"--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in sys.argv
)
args = parser.parse_args()

@@ -44,20 +37,20 @@ def main():
# set arguments for train.py
folder_path = pathlib.Path(__file__).parent.absolute()
command = [os.path.join(folder_path, args.script)]
command.append('--continue_path={}'.format(args.continue_path))
command.append('--restore_path={}'.format(args.restore_path))
command.append('--config_path={}'.format(args.config_path))
command.append('--group_id=group_{}'.format(group_id))
command.append('')
command.append("--continue_path={}".format(args.continue_path))
command.append("--restore_path={}".format(args.restore_path))
command.append("--config_path={}".format(args.config_path))
command.append("--group_id=group_{}".format(group_id))
command.append("")

# run processes
processes = []
for i in range(num_gpus):
my_env = os.environ.copy()
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[-1] = '--rank={}'.format(i)
stdout = None if i == 0 else open(os.devnull, 'w')
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
command[-1] = "--rank={}".format(i)
stdout = None if i == 0 else open(os.devnull, "w")
p = subprocess.Popen(["python3"] + command, stdout=stdout, env=my_env)
processes.append(p)
print(command)

@@ -65,5 +58,5 @@ def main():
p.wait()


if __name__ == '__main__':
if __name__ == "__main__":
main()
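The launcher above relies on a small trick that is easy to miss after the reflow: command.append("") reserves the final slot, and each loop iteration overwrites it with --rank={i} before spawning, so all children share one command list and differ only in rank. A stripped-down sketch of that pattern (the script name and argument are placeholders):

    import subprocess

    command = ["train.py", "--config_path=config.json", ""]  # last slot reserved
    processes = []
    for i in range(2):  # one child per GPU
        command[-1] = "--rank={}".format(i)
        processes.append(subprocess.Popen(["python3"] + command))
    for p in processes:
        p.wait()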
@@ -7,30 +7,24 @@ from TTS.tts.datasets.preprocess import get_preprocessor_by_name


def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Find all the unique characters or phonemes in a dataset.\n\n'''

'''Target dataset must be defined in TTS.tts.datasets.preprocess\n\n'''\
'''
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Find all the unique characters or phonemes in a dataset.\n\n"""
"""Target dataset must be defined in TTS.tts.datasets.preprocess\n\n"""
"""
Example runs:

python TTS/bin/find_unique_chars.py --dataset ljspeech --meta_file /path/to/LJSpeech/metadata.csv
''', formatter_class=RawTextHelpFormatter)

parser.add_argument(
'--dataset',
type=str,
default='',
help='One of the target dataset names in TTS.tts.datasets.preprocess.'
)

parser.add_argument(
'--meta_file',
type=str,
default=None,
help='Path to the transcriptions file of the dataset.'
""",
formatter_class=RawTextHelpFormatter,
)

parser.add_argument(
"--dataset", type=str, default="", help="One of the target dataset names in TTS.tts.datasets.preprocess."
)

parser.add_argument("--meta_file", type=str, default=None, help="Path to the transcriptions file of the dataset.")

args = parser.parse_args()

preprocessor = get_preprocessor_by_name(args.dataset)
@@ -7,15 +7,17 @@ from argparse import RawTextHelpFormatter
from multiprocessing import Pool
from tqdm import tqdm


def resample_file(func_args):
filename, output_sr = func_args
y, sr = librosa.load(filename, sr=output_sr)
librosa.output.write_wav(filename, y, sr)

if __name__ == '__main__':

if __name__ == "__main__":

parser = argparse.ArgumentParser(
description='''Resample a folder recusively with librosa
description="""Resample a folder recusively with librosa
Can be used in place or create a copy of the folder as an output.\n\n
Example run:
python TTS/bin/resample.py
@@ -23,46 +25,52 @@ if __name__ == '__main__':
--output_sr 22050
--output_dir /root/resampled_LJSpeech-1.1/
--n_jobs 24
''',
formatter_class=RawTextHelpFormatter)
""",
formatter_class=RawTextHelpFormatter,
)

parser.add_argument('--input_dir',
type=str,
default=None,
required=True,
help='Path of the folder containing the audio files to resample')
parser.add_argument(
"--input_dir",
type=str,
default=None,
required=True,
help="Path of the folder containing the audio files to resample",
)

parser.add_argument('--output_sr',
type=int,
default=22050,
required=False,
help='Samlple rate to which the audio files should be resampled')
parser.add_argument(
"--output_sr",
type=int,
default=22050,
required=False,
help="Samlple rate to which the audio files should be resampled",
)

parser.add_argument('--output_dir',
type=str,
default=None,
required=False,
help='Path of the destination folder. If not defined, the operation is done in place')
parser.add_argument(
"--output_dir",
type=str,
default=None,
required=False,
help="Path of the destination folder. If not defined, the operation is done in place",
)

parser.add_argument('--n_jobs',
type=int,
default=None,
help='Number of threads to use, by default it uses all cores')
parser.add_argument(
"--n_jobs", type=int, default=None, help="Number of threads to use, by default it uses all cores"
)

args = parser.parse_args()

if args.output_dir:
print('Recursively copying the input folder...')
print("Recursively copying the input folder...")
copy_tree(args.input_dir, args.output_dir)
args.input_dir = args.output_dir

print('Resampling the audio files...')
audio_files = glob.glob(os.path.join(args.input_dir, '**/*.wav'), recursive=True)
print(f'Found {len(audio_files)} files...')
audio_files = list(zip(audio_files, len(audio_files)*[args.output_sr]))
print("Resampling the audio files...")
audio_files = glob.glob(os.path.join(args.input_dir, "**/*.wav"), recursive=True)
print(f"Found {len(audio_files)} files...")
audio_files = list(zip(audio_files, len(audio_files) * [args.output_sr]))
with Pool(processes=args.n_jobs) as p:
with tqdm(total=len(audio_files)) as pbar:
for i, _ in enumerate(p.imap_unordered(resample_file, audio_files)):
pbar.update()

print('Done !')
print("Done !")
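A portability note on resample_file above, unchanged by this commit: librosa.output.write_wav was removed in librosa 0.8.0. On newer librosa the same helper can be written with the soundfile package (assuming it is installed); this is a suggested equivalent, not what the repository ships:

    import librosa
    import soundfile as sf

    def resample_file(func_args):
        filename, output_sr = func_args
        y, sr = librosa.load(filename, sr=output_sr)  # load and resample in one step
        sf.write(filename, y, sr)  # replaces librosa.output.write_wav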
@@ -4,6 +4,7 @@
import argparse
import sys
from argparse import RawTextHelpFormatter

# pylint: disable=redefined-outer-name, unused-argument
from pathlib import Path

@@ -14,22 +15,20 @@ from TTS.utils.synthesizer import Synthesizer
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
if v.lower() in ('no', 'false', 'f', 'n', '0'):
if v.lower() in ("no", "false", "f", "n", "0"):
return False
raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError("Boolean value expected.")


def main():
# pylint: disable=bad-continuation
parser = argparse.ArgumentParser(description='''Synthesize speech on command line.\n\n'''

'''You can either use your trained model or choose a model from the provided list.\n\n'''\

'''If you don't specify any models, then it uses LJSpeech based English models\n\n'''\

'''
# pylint: disable=bad-option-value
parser = argparse.ArgumentParser(
description="""Synthesize speech on command line.\n\n"""
"""You can either use your trained model or choose a model from the provided list.\n\n"""
"""If you don't specify any models, then it uses LJSpeech based English models\n\n"""
"""
Example runs:

# list provided models
@@ -51,106 +50,80 @@ def main():
./TTS/bin/synthesize.py --text "Text for TTS" --model_path path/to/config.json --config_path path/to/model.pth.tar --out_path output/path/speech.wav
--vocoder_path path/to/vocoder.pth.tar --vocoder_config_path path/to/vocoder_config.json

''',
formatter_class=RawTextHelpFormatter)
""",
formatter_class=RawTextHelpFormatter,
)

parser.add_argument(
'--list_models',
"--list_models",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=False,
help='list available pre-trained tts and vocoder models.'
)
parser.add_argument(
'--text',
type=str,
default=None,
help='Text to generate speech.'
)
help="list available pre-trained tts and vocoder models.",
)
parser.add_argument("--text", type=str, default=None, help="Text to generate speech.")

# Args for running pre-trained TTS models.
parser.add_argument(
'--model_name',
"--model_name",
type=str,
default="tts_models/en/ljspeech/speedy-speech-wn",
help=
'Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>'
help="Name of one of the pre-trained tts models in format <language>/<dataset>/<model_name>",
)
parser.add_argument(
'--vocoder_name',
"--vocoder_name",
type=str,
default=None,
help=
'Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>'
help="Name of one of the pre-trained vocoder models in format <language>/<dataset>/<model_name>",
)

# Args for running custom models
parser.add_argument("--config_path", default=None, type=str, help="Path to model config file.")
parser.add_argument(
'--config_path',
default=None,
type=str,
help='Path to model config file.'
)
parser.add_argument(
'--model_path',
"--model_path",
type=str,
default=None,
help='Path to model file.',
help="Path to model file.",
)
parser.add_argument(
'--out_path',
"--out_path",
type=str,
default='tts_output.wav',
help='Output wav file path.',
default="tts_output.wav",
help="Output wav file path.",
)
parser.add_argument("--use_cuda", type=bool, help="Run model on CUDA.", default=False)
parser.add_argument(
'--use_cuda',
type=bool,
help='Run model on CUDA.',
default=False
)
parser.add_argument(
'--vocoder_path',
"--vocoder_path",
type=str,
help=
'Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).',
help="Path to vocoder model file. If it is not defined, model uses GL as vocoder. Please make sure that you installed vocoder library before (WaveRNN).",
default=None,
)
parser.add_argument(
'--vocoder_config_path',
type=str,
help='Path to vocoder model config file.',
default=None)
parser.add_argument("--vocoder_config_path", type=str, help="Path to vocoder model config file.", default=None)

# args for multi-speaker synthesis
parser.add_argument("--speakers_json", type=str, help="JSON file for multi-speaker model.", default=None)
parser.add_argument(
'--speakers_json',
type=str,
help="JSON file for multi-speaker model.",
default=None)
parser.add_argument(
'--speaker_idx',
"--speaker_idx",
type=str,
help="if the tts model is trained with x-vectors, then speaker_idx is a file present in speakers.json else speaker_idx is the speaker id corresponding to a speaker in the speaker embedding layer.",
default=None)
parser.add_argument(
'--gst_style',
help="Wav path file for GST stylereference.",
default=None)
default=None,
)
parser.add_argument("--gst_style", help="Wav path file for GST stylereference.", default=None)

# aux args
parser.add_argument(
'--save_spectogram',
"--save_spectogram",
type=bool,
help="If true save raw spectogram for further (vocoder) processing in out_path.",
default=False)
default=False,
)

args = parser.parse_args()

# print the description if either text or list_models is not set
if args.text is None and not args.list_models:
parser.parse_args(['-h'])
parser.parse_args(["-h"])

# load model manager
path = Path(__file__).parent / "../.models.json"
@@ -169,7 +142,7 @@ def main():
# CASE2: load pre-trained models
if args.model_name is not None:
model_path, config_path, model_item = manager.download_model(args.model_name)
args.vocoder_name = model_item['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
args.vocoder_name = model_item["default_vocoder"] if args.vocoder_name is None else args.vocoder_name

if args.vocoder_name is not None:
vocoder_path, vocoder_config_path, _ = manager.download_model(args.vocoder_name)
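The str2bool converter defined at the top of the file above pairs with the nargs="?" / const=True combination on --list_models to give a tri-state flag, which the bare type=bool still used for --use_cuda and --save_spectogram cannot express. A usage sketch, reusing the str2bool from this diff:

    import argparse

    # str2bool as defined in the hunk above
    parser = argparse.ArgumentParser()
    parser.add_argument("--list_models", type=str2bool, nargs="?", const=True, default=False)

    print(parser.parse_args([]).list_models)                       # False (default)
    print(parser.parse_args(["--list_models"]).list_models)        # True  (const)
    print(parser.parse_args(["--list_models", "no"]).list_models)  # False (parsed)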
@ -25,12 +25,11 @@ from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
|||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.distribute import init_distributed, reduce_tensor
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.radam import RAdam
|
||||
from TTS.utils.training import NoamLR, setup_torch_training_env
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
use_cuda, num_gpus = setup_torch_training_env(True, False)
|
||||
# torch.autograd.set_detect_anomaly(True)
|
||||
|
||||
|
@ -44,10 +43,9 @@ if __name__ == '__main__':
|
|||
compute_linear_spec=False,
|
||||
meta_data=meta_data_eval if is_val else meta_data_train,
|
||||
ap=ap,
|
||||
tp=c.characters if 'characters' in c.keys() else None,
|
||||
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
|
||||
batch_group_size=0 if is_val else c.batch_group_size *
|
||||
c.batch_size,
|
||||
tp=c.characters if "characters" in c.keys() else None,
|
||||
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
|
||||
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
|
||||
min_seq_len=c.min_seq_len,
|
||||
max_seq_len=c.max_seq_len,
|
||||
phoneme_cache_path=c.phoneme_cache_path,
|
||||
|
@ -56,8 +54,10 @@ if __name__ == '__main__':
|
|||
enable_eos_bos=c.enable_eos_bos_chars,
|
||||
use_noise_augment=not is_val,
|
||||
verbose=verbose,
|
||||
speaker_mapping=speaker_mapping if c.use_speaker_embedding
|
||||
and c.use_external_speaker_embedding_file else None)
|
||||
speaker_mapping=speaker_mapping
|
||||
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
|
||||
else None,
|
||||
)
|
||||
|
||||
if c.use_phonemes and c.compute_input_seq_cache:
|
||||
# precompute phonemes to have a better estimate of sequence lengths.
|
||||
|
@ -72,9 +72,9 @@ if __name__ == '__main__':
|
|||
collate_fn=dataset.collate_fn,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=c.num_val_loader_workers
|
||||
if is_val else c.num_loader_workers,
|
||||
pin_memory=False)
|
||||
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
return loader
|
||||
|
||||
def format_data(data):
|
||||
|
@ -94,10 +94,7 @@ if __name__ == '__main__':
|
|||
speaker_c = data[8]
|
||||
else:
|
||||
# return speaker_id to be used by an embedding layer
|
||||
speaker_c = [
|
||||
speaker_mapping[speaker_name]
|
||||
for speaker_name in speaker_names
|
||||
]
|
||||
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
|
||||
speaker_c = torch.LongTensor(speaker_c)
|
||||
else:
|
||||
speaker_c = None
|
||||
|
@ -109,18 +106,15 @@ if __name__ == '__main__':
|
|||
mel_lengths = mel_lengths.cuda(non_blocking=True)
|
||||
if speaker_c is not None:
|
||||
speaker_c = speaker_c.cuda(non_blocking=True)
|
||||
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
|
||||
avg_text_length, avg_spec_length, item_idx
|
||||
return text_input, text_lengths, mel_input, mel_lengths, speaker_c, avg_text_length, avg_spec_length, item_idx
|
||||
|
||||
def train(data_loader, model, criterion, optimizer, scheduler, ap,
|
||||
global_step, epoch, training_phase):
|
||||
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, training_phase):
|
||||
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
if use_cuda:
|
||||
batch_n_iter = int(
|
||||
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
|
@ -130,8 +124,16 @@ if __name__ == '__main__':
|
|||
start_time = time.time()
|
||||
|
||||
# format data
|
||||
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
|
||||
avg_text_length, avg_spec_length, _ = format_data(data)
|
||||
(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_targets,
|
||||
mel_lengths,
|
||||
speaker_c,
|
||||
avg_text_length,
|
||||
avg_spec_length,
|
||||
_,
|
||||
) = format_data(data)
|
||||
|
||||
loader_time = time.time() - end_time
|
||||
|
||||
|
@ -141,36 +143,32 @@ if __name__ == '__main__':
|
|||
# forward pass model
|
||||
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
||||
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
|
||||
text_input,
|
||||
text_lengths,
|
||||
mel_targets,
|
||||
mel_lengths,
|
||||
g=speaker_c,
|
||||
phase=training_phase)
|
||||
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
|
||||
)
|
||||
|
||||
# compute loss
|
||||
loss_dict = criterion(logp,
|
||||
decoder_output,
|
||||
mel_targets,
|
||||
mel_lengths,
|
||||
dur_output,
|
||||
dur_mas_output,
|
||||
text_lengths,
|
||||
global_step,
|
||||
phase=training_phase)
|
||||
loss_dict = criterion(
|
||||
logp,
|
||||
decoder_output,
|
||||
mel_targets,
|
||||
mel_lengths,
|
||||
dur_output,
|
||||
dur_mas_output,
|
||||
text_lengths,
|
||||
global_step,
|
||||
phase=training_phase,
|
||||
)
|
||||
|
||||
# backward pass with loss scaling
|
||||
if c.mixed_precision:
|
||||
scaler.scale(loss_dict['loss']).backward()
|
||||
scaler.scale(loss_dict["loss"]).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), c.grad_clip)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss_dict['loss'].backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), c.grad_clip)
|
||||
loss_dict["loss"].backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
# setup lr
|
||||
|
@ -178,25 +176,21 @@ if __name__ == '__main__':
|
|||
scheduler.step()
|
||||
|
||||
# current_lr
|
||||
current_lr = optimizer.param_groups[0]['lr']
|
||||
current_lr = optimizer.param_groups[0]["lr"]
|
||||
|
||||
# compute alignment error (the lower the better )
|
||||
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
|
||||
loss_dict['align_error'] = align_error
|
||||
loss_dict["align_error"] = align_error
|
||||
|
||||
step_time = time.time() - start_time
|
||||
epoch_time += step_time
|
||||
|
||||
# aggregate losses from processes
|
||||
if num_gpus > 1:
|
||||
loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data,
|
||||
num_gpus)
|
||||
loss_dict['loss_ssim'] = reduce_tensor(
|
||||
loss_dict['loss_ssim'].data, num_gpus)
|
||||
loss_dict['loss_dur'] = reduce_tensor(
|
||||
loss_dict['loss_dur'].data, num_gpus)
|
||||
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
|
||||
num_gpus)
|
||||
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
|
||||
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
|
||||
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
|
||||
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
|
||||
|
||||
# detach loss values
|
||||
loss_dict_new = dict()
|
||||
|
@@ -210,48 +204,43 @@ if __name__ == '__main__':
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)

# print training progress
if global_step % c.print_step == 0:
log_dict = {
"avg_spec_length": [avg_spec_length,
1], # value, precision
"avg_spec_length": [avg_spec_length, 1], # value, precision
"avg_text_length": [avg_text_length, 1],
"step_time": [step_time, 4],
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict,
keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)

if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict['loss'])
save_checkpoint(
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict["loss"],
)

# wait all kernels to be completed
torch.cuda.synchronize()

@@ -259,8 +248,7 @@ if __name__ == '__main__':
# Diagnostic visualizations
if decoder_output is not None:
idx = np.random.randint(mel_targets.shape[0])
pred_spec = decoder_output[idx].detach().data.cpu(
).numpy().T
pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
gt_spec = mel_targets[idx].data.cpu().numpy().T
align_img = alignments[idx].data.cpu()

@@ -274,14 +262,11 @@ if __name__ == '__main__':

# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()

# print epoch stats
c_logger.print_train_epoch_end(global_step, epoch, epoch_time,
keep_avg)
c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

# Plot Epoch Stats
if args.rank == 0:

@@ -293,8 +278,7 @@ if __name__ == '__main__':
return keep_avg.avg_values, global_step

@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch,
training_phase):
def evaluate(data_loader, model, criterion, ap, global_step, epoch, training_phase):
model.eval()
epoch_time = 0
keep_avg = KeepAverage()

@@ -304,50 +288,41 @@ if __name__ == '__main__':
start_time = time.time()

# format data
text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
_, _, _ = format_data(data)
text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _ = format_data(data)

# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
decoder_output, dur_output, dur_mas_output, alignments, _, _, logp = model.forward(
text_input,
text_lengths,
mel_targets,
mel_lengths,
g=speaker_c,
phase=training_phase)
text_input, text_lengths, mel_targets, mel_lengths, g=speaker_c, phase=training_phase
)

# compute loss
loss_dict = criterion(logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase)

loss_dict = criterion(
logp,
decoder_output,
mel_targets,
mel_lengths,
dur_output,
dur_mas_output,
text_lengths,
global_step,
phase=training_phase,
)

# step time
step_time = time.time() - start_time
epoch_time += step_time

# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments,
binary=True)
loss_dict['align_error'] = align_error
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict["align_error"] = align_error

# aggregate losses from processes
if num_gpus > 1:
loss_dict['loss_l1'] = reduce_tensor(
loss_dict['loss_l1'].data, num_gpus)
loss_dict['loss_ssim'] = reduce_tensor(
loss_dict['loss_ssim'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(
loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data,
num_gpus)
loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

# detach loss values
loss_dict_new = dict()
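alignment_diagonal_score, used in both the train and eval loops here, rates how sharp and well focused the attention alignment is, and 1 minus the score is logged as an alignment error. The function ships in TTS.tts.utils.measures and is not part of this diff; a plausible minimal version, assumed purely for illustration:

    import torch

    def alignment_diagonal_score(alignments, binary=False):
        # alignments: [batch, decoder_steps, encoder_steps] attention weights.
        # Average the strongest attention weight per decoder step; sharp,
        # confident alignments push this toward 1.
        maxs = alignments.max(dim=-1)[0]  # [batch, decoder_steps]
        if binary:
            maxs[maxs > 0] = 1.0          # count any attended step as a hit
        return maxs.mean().item()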
@@ -361,12 +336,11 @@ if __name__ == '__main__':
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)

if c.print_eval:
c_logger.print_eval_step(num_iter, loss_dict,
keep_avg.avg_values)
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

if args.rank == 0:
# Diagnostic visualizations

@@ -376,19 +350,14 @@ if __name__ == '__main__':
align_img = alignments[idx].data.cpu()

eval_figures = {
"prediction": plot_spectrogram(pred_spec,
ap,
output_fig=False),
"ground_truth": plot_spectrogram(gt_spec,
ap,
output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False)
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
}

# Sample audio
eval_audio = ap.inv_melspectrogram(pred_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

@@ -401,7 +370,7 @@ if __name__ == '__main__':
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:

@@ -413,9 +382,9 @@ if __name__ == '__main__':
print(" | > Synthesizing test sentences")
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(
speaker_mapping.keys())[randrange(
len(speaker_mapping) - 1)]]['embedding']
speaker_embedding = speaker_mapping[
list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]
]["embedding"]
speaker_id = None
else:
speaker_id = 0

@@ -437,25 +406,22 @@ if __name__ == '__main__':
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)

file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(
idx)] = plot_spectrogram(postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values

@@ -464,69 +430,55 @@ if __name__ == '__main__':
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)

# DISTRUBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

# set model characters
model_characters = phonemes if c.use_phonemes else symbols
num_chars = len(model_characters)

# load data instances
meta_data_train, meta_data_eval = load_meta_data(c.datasets,
eval_split=True)
meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)

# set the portion of the data used for training if set in config.json
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(
len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(
len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]

# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(
c, args, meta_data_train, OUT_PATH)
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)

# setup model
model = setup_model(num_chars,
num_speakers,
c,
speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(),
lr=c.lr,
weight_decay=0,
betas=(0.9, 0.98),
eps=1e-9)
model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
criterion = AlignTTSLoss(c)

if args.restore_path:
print(
f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except
model.load_state_dict(checkpoint["model"])
except: # pylint: disable=bare-except
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict

for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["initial_lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
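When the checkpoint no longer matches the model (for example after a config change), the bare-except branch above falls back to partial initialization: set_init_dict keeps only the checkpoint tensors that still fit the current architecture. The helper is defined in TTS.utils.generic_utils, outside this diff; a hedged sketch of the idea (the real function also consults the config, here accepted as `c` but unused):

    def set_init_dict(model_dict, checkpoint_state, c):
        # Keep only checkpoint tensors that exist in the model with a
        # matching number of elements; everything else stays freshly
        # initialized.
        pretrained = {
            k: v
            for k, v in checkpoint_state.items()
            if k in model_dict and v.numel() == model_dict[k].numel()
        }
        model_dict.update(pretrained)
        print(f" | > {len(pretrained)} / {len(model_dict)} layers are restored.")
        return model_dict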
@@ -539,9 +491,7 @@ if __name__ == '__main__':
model = DDP_th(model, device_ids=[args.rank])

if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None
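NoamLR is the Transformer-style schedule: roughly linear warmup for warmup_steps, then inverse-square-root decay, and passing last_epoch=args.restore_step - 1 resumes the curve at the restored step. A sketch of the usual rule (the actual class ships in TTS.utils.training and may differ in detail):

    from torch.optim.lr_scheduler import _LRScheduler

    class NoamLR(_LRScheduler):
        def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
            self.warmup_steps = float(warmup_steps)
            super().__init__(optimizer, last_epoch)

        def get_lr(self):
            # Warmup then inverse-sqrt decay; peaks at base_lr when
            # step == warmup_steps.
            step = max(self.last_epoch, 1)
            scale = self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5)
            return [base_lr * scale for base_lr in self.base_lrs]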
@@ -549,16 +499,14 @@ if __name__ == '__main__':
print("\n > Model has {} parameters".format(num_params), flush=True)

if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False

# define dataloaders
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)

@@ -573,9 +521,9 @@ if __name__ == '__main__':
if not True in vals:
phase = 0
else:
phase = len(c.phase_start_steps) - [
i < global_step for i in c.phase_start_steps
][::-1].index(True) - 1
phase = (
len(c.phase_start_steps) - [i < global_step for i in c.phase_start_steps][::-1].index(True) - 1
)
else:
phase = None
return phase
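set_phase picks the newest AlignTTS phase whose start step has already been passed: it walks the boolean list from the end and counts back from the last True. A worked example with an assumed config:

    # Assumed config: AlignTTS phases begin at these global steps.
    phase_start_steps = [0, 20000, 40000, 80000, 160000]
    global_step = 50000

    vals = [i < global_step for i in phase_start_steps]
    # -> [True, True, True, False, False]
    phase = len(phase_start_steps) - vals[::-1].index(True) - 1
    print(phase)  # 2: the third phase, which started at step 40000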
@@ -584,32 +532,30 @@ if __name__ == '__main__':
cur_phase = set_phase()
print(f"\n > Current AlignTTS phase: {cur_phase}")
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model,
criterion, optimizer,
scheduler, ap,
global_step, epoch,
cur_phase)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch, cur_phase)
train_avg_loss_dict, global_step = train(
train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch, cur_phase
)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch, cur_phase)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
target_loss = train_avg_loss_dict["avg_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after)
target_loss = eval_avg_loss_dict["avg_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
)

args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")

try:
main(args)

@@ -12,14 +12,17 @@ from torch.utils.data import DataLoader
from TTS.speaker_encoder.dataset import MyDataset
from TTS.speaker_encoder.losses import AngleProtoLoss, GE2ELoss
from TTS.speaker_encoder.model import SpeakerEncoder
from TTS.speaker_encoder.utils.generic_utils import \
check_config_speaker_encoder, save_best_model
from TTS.speaker_encoder.utils.generic_utils import check_config_speaker_encoder, save_best_model
from TTS.speaker_encoder.utils.visual import plot_embeddings
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (count_parameters,
create_experiment_folder, get_git_branch,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import (
count_parameters,
create_experiment_folder,
get_git_branch,
remove_experiment_folder,
set_init_dict,
)
from TTS.utils.io import copy_model_files, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger

@@ -34,28 +37,30 @@ print(" > Using CUDA: ", use_cuda)
print(" > Number of GPUs: ", num_gpus)


def setup_loader(ap: AudioProcessor,
is_val: bool = False,
verbose: bool = False):
def setup_loader(ap: AudioProcessor, is_val: bool = False, verbose: bool = False):
if is_val:
loader = None
else:
dataset = MyDataset(ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=False,
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
additive_noise=c.storage["additive_noise"],
verbose=verbose)
dataset = MyDataset(
ap,
meta_data_eval if is_val else meta_data_train,
voice_len=1.6,
num_utter_per_speaker=c.num_utters_per_speaker,
num_speakers_in_batch=c.num_speakers_in_batch,
skip_speakers=False,
storage_size=c.storage["storage_size"],
sample_from_storage_p=c.storage["sample_from_storage_p"],
additive_noise=c.storage["additive_noise"],
verbose=verbose,
)
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
loader = DataLoader(dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn)
loader = DataLoader(
dataset,
batch_size=c.num_speakers_in_batch,
shuffle=False,
num_workers=c.num_loader_workers,
collate_fn=dataset.collate_fn,
)
return loader


@@ -63,7 +68,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
data_loader = setup_loader(ap, is_val=False, verbose=True)
model.train()
epoch_time = 0
best_loss = float('inf')
best_loss = float("inf")
avg_loss = 0
avg_loader_time = 0
end_time = time.time()

@@ -89,9 +94,7 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
outputs = model(inputs)

# loss computation
loss = criterion(
outputs.view(c.num_speakers_in_batch,
outputs.shape[0] // c.num_speakers_in_batch, -1))
loss = criterion(outputs.view(c.num_speakers_in_batch, outputs.shape[0] // c.num_speakers_in_batch, -1))
loss.backward()
grad_norm, _ = check_update(model, c.grad_clip)
optimizer.step()

@@ -100,11 +103,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
epoch_time += step_time

# Averaged Loss and Averaged Loader Time
avg_loss = 0.01 * loss.item() \
+ 0.99 * avg_loss if avg_loss != 0 else loss.item()
avg_loader_time = 1/c.num_loader_workers * loader_time + \
(c.num_loader_workers-1) / c.num_loader_workers * avg_loader_time if avg_loader_time != 0 else loader_time
current_lr = optimizer.param_groups[0]['lr']
avg_loss = 0.01 * loss.item() + 0.99 * avg_loss if avg_loss != 0 else loss.item()
avg_loader_time = (
1 / c.num_loader_workers * loader_time + (c.num_loader_workers - 1) / c.num_loader_workers * avg_loader_time
if avg_loader_time != 0
else loader_time
)
current_lr = optimizer.param_groups[0]["lr"]

if global_step % c.steps_plot_stats == 0:
# Plot Training Epoch Stats
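Both running statistics above are exponential moving averages: the loss uses a 0.99 smoothing factor and the loader time uses (num_loader_workers - 1) / num_loader_workers, with the first observation seeding the average. A standalone sketch with hypothetical values:

    def ema(prev, new, alpha):
        # Exponential moving average; the first observation seeds it.
        return new if prev == 0 else alpha * prev + (1 - alpha) * new

    avg_loss = 0.0
    for step_loss in [2.0, 1.5, 1.2]:  # hypothetical per-step losses
        avg_loss = ema(avg_loss, step_loss, alpha=0.99)
    print(avg_loss)  # ~1.987: still dominated by the first observation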
@@ -113,13 +118,12 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time,
"avg_loader_time": avg_loader_time
"avg_loader_time": avg_loader_time,
}
tb_logger.tb_train_epoch_stats(global_step, train_stats)
figures = {
# FIXME: not constant
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(),
10),
"UMAP Plot": plot_embeddings(outputs.detach().cpu().numpy(), 10),
}
tb_logger.tb_train_figures(global_step, figures)

@@ -127,13 +131,13 @@ def train(model, criterion, optimizer, scheduler, ap, global_step):
print(
" | > Step:{} Loss:{:.5f} AvgLoss:{:.5f} GradNorm:{:.5f} "
"StepTime:{:.2f} LoaderTime:{:.2f} AvGLoaderTime:{:.2f} LR:{:.6f}".format(
global_step, loss.item(), avg_loss, grad_norm, step_time,
loader_time, avg_loader_time, current_lr),
flush=True)
global_step, loss.item(), avg_loss, grad_norm, step_time, loader_time, avg_loader_time, current_lr
),
flush=True,
)

# save best model
best_loss = save_best_model(model, optimizer, avg_loss, best_loss,
OUT_PATH, global_step)
best_loss = save_best_model(model, optimizer, avg_loss, best_loss, OUT_PATH, global_step)

end_time = time.time()
return avg_loss, global_step

@@ -145,14 +149,16 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_eval

ap = AudioProcessor(**c.audio)
model = SpeakerEncoder(input_dim=c.model['input_dim'],
proj_dim=c.model['proj_dim'],
lstm_dim=c.model['lstm_dim'],
num_lstm_layers=c.model['num_lstm_layers'])
model = SpeakerEncoder(
input_dim=c.model["input_dim"],
proj_dim=c.model["proj_dim"],
lstm_dim=c.model["lstm_dim"],
num_lstm_layers=c.model["num_lstm_layers"],
)
optimizer = RAdam(model.parameters(), lr=c.lr)

if c.loss == "ge2e":
criterion = GE2ELoss(loss_method='softmax')
criterion = GE2ELoss(loss_method="softmax")
elif c.loss == "angleproto":
criterion = AngleProtoLoss()
else:

@@ -166,7 +172,7 @@ def main(args): # pylint: disable=redefined-outer-name
# optimizer.load_state_dict(checkpoint['optimizer'])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
model.load_state_dict(checkpoint["model"])
except KeyError:
print(" > Partial model initialization.")
model_dict = model.state_dict()

@@ -174,10 +180,9 @@ def main(args): # pylint: disable=redefined-outer-name
model.load_state_dict(model_dict)
del model_dict
for group in optimizer.param_groups:
group['lr'] = c.lr
print(" > Model restored from step %d" % checkpoint['step'],
flush=True)
args.restore_step = checkpoint['step']
group["lr"] = c.lr
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0

@@ -186,9 +191,7 @@ def main(args): # pylint: disable=redefined-outer-name
criterion.cuda()

if c.lr_decay:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None

@@ -199,55 +202,39 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)

global_step = args.restore_step
_, global_step = train(model, criterion, optimizer, scheduler, ap,
global_step)
_, global_step = train(model, criterion, optimizer, scheduler, ap, global_step)


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--restore_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).',
default=0)
"--restore_path", type=str, help="Path to model outputs (checkpoint, tensorboard etc.).", default=0
)
parser.add_argument(
'--config_path',
"--config_path",
type=str,
required=True,
help='Path to config file for training.',
help="Path to config file for training.",
)
parser.add_argument('--debug',
type=bool,
default=True,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path',
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')
parser.add_argument("--debug", type=bool, default=True, help="Do not verify commit integrity to run training.")
parser.add_argument("--data_path", type=str, default="", help="Defines the data path. It overwrites config.json.")
parser.add_argument("--output_path", type=str, help="path for training outputs.", default="")
parser.add_argument("--output_folder", type=str, default="", help="folder name for training outputs.")
args = parser.parse_args()

# setup output paths and read configs
c = load_config(args.config_path)
check_config_speaker_encoder(c)
_ = os.path.dirname(os.path.realpath(__file__))
if args.data_path != '':
if args.data_path != "":
c.data_path = args.data_path

if args.output_path == '':
if args.output_path == "":
OUT_PATH = os.path.join(_, c.output_path)
else:
OUT_PATH = args.output_path

if args.output_folder == '':
if args.output_folder == "":
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
else:
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)

@@ -259,7 +246,7 @@ if __name__ == '__main__':
copy_model_files(c, args.config_path, OUT_PATH, new_fields)

LOG_DIR = OUT_PATH
tb_logger = TensorboardLogger(LOG_DIR, model_name='Speaker_Encoder')
tb_logger = TensorboardLogger(LOG_DIR, model_name="Speaker_Encoder")

try:
main(args)

@@ -8,6 +8,7 @@ import traceback
from random import randrange

import torch

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

@@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env

@@ -44,19 +44,21 @@ def setup_loader(ap, r, is_val=False, verbose=False):
compute_linear_spec=False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=c['use_noise_augment'] and not is_val,
use_noise_augment=c["use_noise_augment"] and not is_val,
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
)

if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.

@@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader


@@ -95,9 +97,7 @@ def format_data(data):
speaker_c = data[8]
else:
# return speaker_id to be used by an embedding layer
speaker_c = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_c = torch.LongTensor(speaker_c)
else:
speaker_c = None

@@ -112,13 +112,22 @@ def format_data(data):
speaker_c = speaker_c.cuda(non_blocking=True)
if attn_mask is not None:
attn_mask = attn_mask.cuda(non_blocking=True)
return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, item_idx
return (
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
attn_mask,
item_idx,
)


def data_depended_init(data_loader, model):
"""Data depended initialization for activation normalization."""
if hasattr(model, 'module'):
if hasattr(model, "module"):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(True)
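data_depended_init implements the Glow-TTS warm-up for ActNorm layers: DDI is switched on, a few batches are pushed through so each flow can initialize its scale and bias from activation statistics, and the loop below switches it back off before regular training. A hedged sketch of the toggle, following the attribute names in the diff (the flow internals themselves are assumed):

    def toggle_ddi(model, enabled):
        # Unwrap DistributedDataParallel if present, then flip DDI on every
        # decoder flow that supports it.
        net = model.module if hasattr(model, "module") else model
        for f in net.decoder.flows:
            if getattr(f, "set_ddi", False):
                f.set_ddi(enabled)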
@@ -134,17 +143,15 @@ def data_depended_init(data_loader, model):
for _, data in enumerate(data_loader):

# format data
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed,\
_, _, attn_mask, _ = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, spekaer_embed, _, _, attn_mask, _ = format_data(data)

# forward pass model
_ = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
_ = model.forward(text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=spekaer_embed)
if num_iter == c.data_dep_init_iter:
break
num_iter += 1

if hasattr(model, 'module'):
if hasattr(model, "module"):
for f in model.module.decoder.flows:
if getattr(f, "set_ddi", False):
f.set_ddi(False)

@@ -155,15 +162,13 @@ def data_depended_init(data_loader, model):
return model


def train(data_loader, model, criterion, optimizer, scheduler,
ap, global_step, epoch):
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):

model.train()
epoch_time = 0
keep_avg = KeepAverage()
if use_cuda:
batch_n_iter = int(
len(data_loader.dataset) / (c.batch_size * num_gpus))
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
else:
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
end_time = time.time()

@@ -173,8 +178,17 @@ def train(data_loader, model, criterion, optimizer, scheduler,
start_time = time.time()

# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
avg_text_length, avg_spec_length, attn_mask, _ = format_data(data)
(
text_input,
text_lengths,
mel_input,
mel_lengths,
speaker_c,
avg_text_length,
avg_spec_length,
attn_mask,
_,
) = format_data(data)

loader_time = time.time() - end_time

@@ -184,24 +198,22 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# forward pass model
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
)

# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)

# backward pass with loss scaling
if c.mixed_precision:
scaler.scale(loss_dict['loss']).backward()
scaler.scale(loss_dict["loss"]).backward()
scaler.unscale_(optimizer)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
scaler.step(optimizer)
scaler.update()
else:
loss_dict['loss'].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
c.grad_clip)
loss_dict["loss"].backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
optimizer.step()

# setup lr

@@ -209,20 +221,20 @@ def train(data_loader, model, criterion, optimizer, scheduler,
scheduler.step()

# current_lr
current_lr = optimizer.param_groups[0]['lr']
current_lr = optimizer.param_groups[0]["lr"]

# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(alignments, binary=True)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error

step_time = time.time() - start_time
epoch_time += step_time

# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

# detach loss values
loss_dict_new = dict()

@@ -236,9 +248,9 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values['avg_loader_time'] = loader_time
update_train_values['avg_step_time'] = step_time
update_train_values["avg_" + key] = value
update_train_values["avg_loader_time"] = loader_time
update_train_values["avg_step_time"] = step_time
keep_avg.update_values(update_train_values)

# print training progress
@@ -250,26 +262,29 @@ def train(data_loader, model, criterion, optimizer, scheduler,
"loader_time": [loader_time, 2],
"current_lr": current_lr,
}
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
log_dict, loss_dict, keep_avg.avg_values)
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

if args.rank == 0:
# Plot Training Iter Stats
# reduce TB load
if global_step % c.tb_plot_step == 0:
iter_stats = {
"lr": current_lr,
"grad_norm": grad_norm,
"step_time": step_time
}
iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
iter_stats.update(loss_dict)
tb_logger.tb_train_iter_stats(global_step, iter_stats)

if global_step % c.save_step == 0:
if c.checkpoint:
# save model
save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
model_loss=loss_dict['loss'])
save_checkpoint(
model,
optimizer,
global_step,
epoch,
1,
OUT_PATH,
model_characters,
model_loss=loss_dict["loss"],
)

# wait all kernels to be completed
torch.cuda.synchronize()

@@ -278,7 +293,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,
# direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1]

if hasattr(model, 'module'):
if hasattr(model, "module"):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)

@@ -299,9 +314,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,

# Sample audio
train_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_train_audios(global_step,
{'TrainAudio': train_audio},
c.audio["sample_rate"])
tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
end_time = time.time()

# print epoch stats

@@ -328,16 +341,15 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
start_time = time.time()

# format data
text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
_, _, attn_mask, _ = format_data(data)
text_input, text_lengths, mel_input, mel_lengths, speaker_c, _, _, attn_mask, _ = format_data(data)

# forward pass model
z, logdet, y_mean, y_log_scale, alignments, o_dur_log, o_total_dur = model.forward(
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c)
text_input, text_lengths, mel_input, mel_lengths, attn_mask, g=speaker_c
)

# compute loss
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths,
o_dur_log, o_total_dur, text_lengths)
loss_dict = criterion(z, y_mean, y_log_scale, logdet, mel_lengths, o_dur_log, o_total_dur, text_lengths)

# step time
step_time = time.time() - start_time

@@ -345,13 +357,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

# compute alignment score
align_error = 1 - alignment_diagonal_score(alignments)
loss_dict['align_error'] = align_error
loss_dict["align_error"] = align_error

# aggregate losses from processes
if num_gpus > 1:
loss_dict['log_mle'] = reduce_tensor(loss_dict['log_mle'].data, num_gpus)
loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
loss_dict["log_mle"] = reduce_tensor(loss_dict["log_mle"].data, num_gpus)
loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

# detach loss values
loss_dict_new = dict()

@@ -365,7 +377,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# update avg stats
update_train_values = dict()
for key, value in loss_dict.items():
update_train_values['avg_' + key] = value
update_train_values["avg_" + key] = value
keep_avg.update_values(update_train_values)

if c.print_eval:

@@ -375,7 +387,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
# Diagnostic visualizations
# direct pass on model for spec predictions
target_speaker = None if speaker_c is None else speaker_c[:1]
if hasattr(model, 'module'):
if hasattr(model, "module"):
spec_pred, *_ = model.module.inference(text_input[:1], text_lengths[:1], g=target_speaker)
else:
spec_pred, *_ = model.inference(text_input[:1], text_lengths[:1], g=target_speaker)

@@ -389,13 +401,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
eval_figures = {
"prediction": plot_spectrogram(const_spec, ap),
"ground_truth": plot_spectrogram(gt_spec, ap),
"alignment": plot_alignment(align_img)
"alignment": plot_alignment(align_img),
}

# Sample audio
eval_audio = ap.inv_melspectrogram(const_spec.T)
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
c.audio["sample_rate"])
tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

# Plot Validation Stats
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)

@@ -408,7 +419,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
"Be a voice, not an echo.",
"I'm sorry Dave. I'm afraid I can't do that.",
"This cake is great. It's so delicious and moist.",
"Prior to November 22, 1963."
"Prior to November 22, 1963.",
]
else:
with open(c.test_sentences_file, "r") as f:

@@ -420,7 +431,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
print(" | > Synthesizing test sentences")
if c.use_speaker_embedding:
if c.use_external_speaker_embedding_file:
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
"embedding"
]
speaker_id = None
else:
speaker_id = 0

@@ -442,25 +455,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
speaker_embedding=speaker_embedding,
style_wav=style_wav,
truncated=False,
enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
enable_eos_bos_chars=c.enable_eos_bos_chars, # pylint: disable=unused-argument
use_griffin_lim=True,
do_trim_silence=False)
do_trim_silence=False,
)

file_path = os.path.join(AUDIO_PATH, str(global_step))
os.makedirs(file_path, exist_ok=True)
file_path = os.path.join(file_path,
"TestSentence_{}.wav".format(idx))
file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
ap.save_wav(wav, file_path)
test_audios['{}-audio'.format(idx)] = wav
test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
postnet_output, ap)
test_figures['{}-alignment'.format(idx)] = plot_alignment(
alignment)
except: #pylint: disable=bare-except
test_audios["{}-audio".format(idx)] = wav
test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
except: # pylint: disable=bare-except
print(" !! Error creating Test Sentence -", idx)
traceback.print_exc()
tb_logger.tb_test_audios(global_step, test_audios,
c.audio['sample_rate'])
tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
tb_logger.tb_test_figures(global_step, test_figures)
return keep_avg.avg_values
@@ -470,13 +480,12 @@ def main(args): # pylint: disable=redefined-outer-name
global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
# Audio processor
ap = AudioProcessor(**c.audio)
if 'characters' in c.keys():
if "characters" in c.keys():
symbols, phonemes = make_symbols(**c.characters)

# DISTRUBUTED
if num_gpus > 1:
init_distributed(args.rank, num_gpus, args.group_id,
c.distributed["backend"], c.distributed["url"])
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

# set model characters
model_characters = phonemes if c.use_phonemes else symbols

@@ -486,10 +495,10 @@ def main(args): # pylint: disable=redefined-outer-name
meta_data_train, meta_data_eval = load_meta_data(c.datasets)

# set the portion of the data used for training
if 'train_portion' in c.keys():
meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
if 'eval_portion' in c.keys():
meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
if "train_portion" in c.keys():
meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
if "eval_portion" in c.keys():
meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]

# parse speakers
num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)

@@ -501,26 +510,25 @@ def main(args): # pylint: disable=redefined-outer-name

if args.restore_path:
print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
checkpoint = torch.load(args.restore_path, map_location='cpu')
checkpoint = torch.load(args.restore_path, map_location="cpu")
try:
# TODO: fix optimizer init, model.cuda() needs to be called before
# optimizer restore
optimizer.load_state_dict(checkpoint['optimizer'])
optimizer.load_state_dict(checkpoint["optimizer"])
if c.reinit_layers:
raise RuntimeError
model.load_state_dict(checkpoint['model'])
except: #pylint: disable=bare-except
model.load_state_dict(checkpoint["model"])
except: # pylint: disable=bare-except
print(" > Partial model initialization.")
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
model.load_state_dict(model_dict)
del model_dict

for group in optimizer.param_groups:
group['initial_lr'] = c.lr
print(f" > Model restored from step {checkpoint['step']:d}",
flush=True)
args.restore_step = checkpoint['step']
group["initial_lr"] = c.lr
print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
args.restore_step = checkpoint["step"]
else:
args.restore_step = 0
@@ -533,9 +541,7 @@ def main(args): # pylint: disable=redefined-outer-name
model = DDP_th(model, device_ids=[args.rank])

if c.noam_schedule:
scheduler = NoamLR(optimizer,
warmup_steps=c.warmup_steps,
last_epoch=args.restore_step - 1)
scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
else:
scheduler = None

@@ -543,16 +549,14 @@ def main(args): # pylint: disable=redefined-outer-name
print("\n > Model has {} parameters".format(num_params), flush=True)

if args.restore_step == 0 or not args.best_path:
best_loss = float('inf')
best_loss = float("inf")
print(" > Starting with inf best loss.")
else:
print(" > Restoring best loss from "
f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path,
map_location='cpu')['model_loss']
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
print(f" > Starting with loaded last best loss {best_loss}.")
keep_all_best = c.get('keep_all_best', False)
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
keep_all_best = c.get("keep_all_best", False)
keep_after = c.get("keep_after", 10000) # void if keep_all_best False

# define dataloaders
train_loader = setup_loader(ap, 1, is_val=False, verbose=True)

@@ -562,25 +566,32 @@ def main(args): # pylint: disable=redefined-outer-name
model = data_depended_init(train_loader, model)
for epoch in range(0, c.epochs):
c_logger.print_epoch_start(epoch, c.epochs)
train_avg_loss_dict, global_step = train(train_loader, model,
criterion, optimizer,
scheduler, ap, global_step,
epoch)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
global_step, epoch)
train_avg_loss_dict, global_step = train(
train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
)
eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
target_loss = train_avg_loss_dict['avg_loss']
target_loss = train_avg_loss_dict["avg_loss"]
if c.run_eval:
target_loss = eval_avg_loss_dict['avg_loss']
best_loss = save_best_model(target_loss, best_loss, model, optimizer,
global_step, epoch, c.r, OUT_PATH, model_characters,
keep_all_best=keep_all_best, keep_after=keep_after)
target_loss = eval_avg_loss_dict["avg_loss"]
best_loss = save_best_model(
target_loss,
best_loss,
model,
optimizer,
global_step,
epoch,
c.r,
OUT_PATH,
model_characters,
keep_all_best=keep_all_best,
keep_after=keep_after,
)


if __name__ == '__main__':
if __name__ == "__main__":
args = parse_arguments(sys.argv)
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
args, model_class='tts')
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")

try:
main(args)
@@ -10,6 +10,7 @@ from random import randrange

import torch
from TTS.utils.arguments import parse_arguments, process_args

# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader

@@ -26,8 +27,7 @@ from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import NoamLR, setup_torch_training_env

@@ -44,10 +44,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
compute_linear_spec=False,
meta_data=meta_data_eval if is_val else meta_data_train,
ap=ap,
tp=c.characters if 'characters' in c.keys() else None,
add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size *
c.batch_size,
tp=c.characters if "characters" in c.keys() else None,
add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=c.min_seq_len,
max_seq_len=c.max_seq_len,
phoneme_cache_path=c.phoneme_cache_path,

@@ -56,7 +55,10 @@ def setup_loader(ap, r, is_val=False, verbose=False):
enable_eos_bos=c.enable_eos_bos_chars,
use_noise_augment=not is_val,
verbose=verbose,
speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)
speaker_mapping=speaker_mapping
if c.use_speaker_embedding and c.use_external_speaker_embedding_file
else None,
)

if c.use_phonemes and c.compute_input_seq_cache:
# precompute phonemes to have a better estimate of sequence lengths.

@@ -71,9 +73,9 @@ def setup_loader(ap, r, is_val=False, verbose=False):
collate_fn=dataset.collate_fn,
drop_last=False,
sampler=sampler,
num_workers=c.num_val_loader_workers
if is_val else c.num_loader_workers,
pin_memory=False)
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
pin_memory=False,
)
return loader


@@ -95,9 +97,7 @@ def format_data(data):
speaker_c = data[8]
else:
# return speaker_id to be used by an embedding layer
speaker_c = [
speaker_mapping[speaker_name] for speaker_name in speaker_names
]
speaker_c = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
speaker_c = torch.LongTensor(speaker_c)
else:
speaker_c = None

@@ -105,7 +105,7 @@ def format_data(data):
durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
for idx, am in enumerate(attn_mask):
# compute raw durations
c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
c_idxs = am[:, : text_lengths[idx], : mel_lengths[idx]].max(1)[1]
# c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
c_idxs, counts = torch.unique(c_idxs, return_counts=True)
dur = torch.ones([text_lengths[idx]]).to(counts.dtype)

@@ -115,8 +115,10 @@ def format_data(data):
extra_frames = dur.sum() - mel_lengths[idx]
largest_idxs = torch.argsort(-dur)[:extra_frames]
dur[largest_idxs] -= 1
assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, :text_lengths[idx]] = dur
assert (
dur.sum() == mel_lengths[idx]
), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
durations[idx, : text_lengths[idx]] = dur
# dispatch data to GPU
if use_cuda:
text_input = text_input.cuda(non_blocking=True)
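For SpeedySpeech, format_data turns the precomputed attention mask into integer per-character durations: each spectrogram frame votes for its argmax character, the vote counts become durations, and surplus frames are trimmed from the longest entries until the total matches the spectrogram length. A toy sketch of that bookkeeping, with assumed shapes and values:

    import torch

    # Toy alignment: 10 decoder frames, each voting for one of 4 characters.
    frame_votes = torch.tensor([0, 0, 1, 1, 1, 2, 2, 3, 3, 3])
    num_chars, num_frames = 4, 10

    chars, counts = torch.unique(frame_votes, return_counts=True)
    dur = torch.ones(num_chars, dtype=counts.dtype)  # unseen chars keep 1
    dur[chars] = counts                              # [2, 3, 2, 3]

    extra = int(dur.sum()) - num_frames  # >0 when the floor of 1 over-counts
    if extra > 0:
        dur[torch.argsort(-dur)[:extra]] -= 1  # trim the longest durations
    assert dur.sum() == num_frames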
@@ -127,19 +129,27 @@ def format_data(data):
        speaker_c = speaker_c.cuda(non_blocking=True)
        attn_mask = attn_mask.cuda(non_blocking=True)
        durations = durations.cuda(non_blocking=True)
    return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
        avg_text_length, avg_spec_length, attn_mask, durations, item_idx
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        speaker_c,
        avg_text_length,
        avg_spec_length,
        attn_mask,
        durations,
        item_idx,
    )


def train(data_loader, model, criterion, optimizer, scheduler,
          ap, global_step, epoch):
def train(data_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch):

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
@@ -149,8 +159,18 @@ def train(data_loader, model, criterion, optimizer, scheduler,
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
            avg_text_length, avg_spec_length, _, dur_target, _ = format_data(data)
        (
            text_input,
            text_lengths,
            mel_targets,
            mel_lengths,
            speaker_c,
            avg_text_length,
            avg_spec_length,
            _,
            dur_target,
            _,
        ) = format_data(data)

        loader_time = time.time() - end_time
@@ -160,23 +180,24 @@ def train(data_loader, model, criterion, optimizer, scheduler,
        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            decoder_output, dur_output, alignments = model.forward(
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
            )

            # compute loss
            loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
            loss_dict = criterion(
                decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
            )

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            loss_dict["loss"].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
            optimizer.step()

        # setup lr
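
The reflowed branch is the standard torch.cuda.amp recipe when gradients are clipped: scale, backward, unscale, clip, step, update. A self-contained sketch with a toy model standing in for the TTS network (assumes a CUDA device):

import torch

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()
grad_clip = 5.0  # stand-in for c.grad_clip

x, y = torch.randn(8, 10).cuda(), torch.randn(8, 1).cuda()
optimizer.zero_grad()
with torch.cuda.amp.autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # bring grads back to fp32 scale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
scaler.step(optimizer)      # skips the update if gradients overflowed
scaler.update()
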
@@ -184,21 +205,21 @@ def train(data_loader, model, criterion, optimizer, scheduler,
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']
        current_lr = optimizer.param_groups[0]["lr"]

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict['align_error'] = align_error
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
            loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
            loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
            loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
            loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
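
reduce_tensor averages each loss across distributed workers so every rank logs the same value. A minimal sketch of what such a helper typically does (an assumption based on the call sites here, not a copy of TTS.utils.distribute):

import torch.distributed as dist

def reduce_tensor(tensor, n_gpus):
    # sum across all ranks, then average
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / n_gpus
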
@@ -212,41 +233,43 @@ def train(data_loader, model, criterion, optimizer, scheduler,
        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {

                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats = {"lr": current_lr, "grad_norm": grad_norm, "step_time": step_time}
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH, model_characters,
                                    model_loss=loss_dict['loss'])
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        1,
                        OUT_PATH,
                        model_characters,
                        model_loss=loss_dict["loss"],
                    )

                # wait all kernels to be completed
                torch.cuda.synchronize()
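
keep_avg accumulates a running mean for every value prefixed with "avg_". A minimal stand-in for what a KeepAverage-like tracker needs to do (the real class lives in TTS.utils.generic_utils; this sketch is an assumption, not its implementation):

class RunningAverages:
    def __init__(self):
        self.sums, self.counts = {}, {}

    def update_values(self, values):
        for name, value in values.items():
            self.sums[name] = self.sums.get(name, 0.0) + float(value)
            self.counts[name] = self.counts.get(name, 0) + 1

    @property
    def avg_values(self):
        return {n: self.sums[n] / self.counts[n] for n in self.sums}
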
@@ -267,9 +290,7 @@ def train(data_loader, model, criterion, optimizer, scheduler,

        # Sample audio
        train_audio = ap.inv_melspectrogram(pred_spec.T)
        tb_logger.tb_train_audios(global_step,
                                  {'TrainAudio': train_audio},
                                  c.audio["sample_rate"])
        tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
@@ -296,16 +317,18 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
                _, _, _, dur_target, _ = format_data(data)
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c, _, _, _, dur_target, _ = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c
                )

                # compute loss
                loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)
                loss_dict = criterion(
                    decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths
                )

            # step time
            step_time = time.time() - start_time
@@ -313,14 +336,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments, binary=True)
            loss_dict['align_error'] = align_error
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
                loss_dict["loss_l1"] = reduce_tensor(loss_dict["loss_l1"].data, num_gpus)
                loss_dict["loss_ssim"] = reduce_tensor(loss_dict["loss_ssim"].data, num_gpus)
                loss_dict["loss_dur"] = reduce_tensor(loss_dict["loss_dur"].data, num_gpus)
                loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
@@ -334,7 +357,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values['avg_' + key] = value
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

        if c.print_eval:
@@ -350,13 +373,12 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
        eval_figures = {
            "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False)
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
        eval_audio = ap.inv_melspectrogram(pred_spec.T)
        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                 c.audio["sample_rate"])
        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

        # Plot Validation Stats
        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
@@ -369,7 +391,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
                "Prior to November 22, 1963.",
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
@@ -381,7 +403,9 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]][
                    "embedding"
                ]
                speaker_id = None
            else:
                speaker_id = 0
@@ -403,25 +427,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except: #pylint: disable=bare-except
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
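
One detail black leaves untouched: randrange(len(speaker_mapping) - 1) can never pick the last speaker, because randrange(n) already samples 0 to n - 1. The reformat preserves that quirk verbatim; a quick demonstration of the semantics:

from random import randrange

# randrange(3) yields 0, 1 or 2, never 3
assert all(randrange(3) in (0, 1, 2) for _ in range(1000))
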
@@ -432,13 +453,12 @@ def main(args): # pylint: disable=redefined-outer-name
    global meta_data_train, meta_data_eval, symbols, phonemes, model_characters, speaker_mapping
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
    if "characters" in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])

    # set model characters
    model_characters = phonemes if c.use_phonemes else symbols
@@ -448,10 +468,10 @@ def main(args): # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)

    # set the portion of the data used for training if set in config.json
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
    if "train_portion" in c.keys():
        meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
    if "eval_portion" in c.keys():
        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
@@ -463,26 +483,25 @@ def main(args): # pylint: disable=redefined-outer-name

    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)} ...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            optimizer.load_state_dict(checkpoint['optimizer'])
            optimizer.load_state_dict(checkpoint["optimizer"])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except: #pylint: disable=bare-except
            model.load_state_dict(checkpoint["model"])
        except:  # pylint: disable=bare-except
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group['initial_lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
            group["initial_lr"] = c.lr
        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0
@@ -495,9 +514,7 @@ def main(args): # pylint: disable=redefined-outer-name
        model = DDP_th(model, device_ids=[args.rank])

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None
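
NoamLR implements the Transformer warmup schedule: the learning rate grows linearly for warmup_steps, then decays with the inverse square root of the step count. A one-function approximation of what the scheduler computes (not the exact TTS.utils.training.NoamLR class):

def noam_lr(base_lr, step, warmup_steps):
    # linear warmup, then inverse-square-root decay; peak at step == warmup_steps
    step = max(step, 1)
    return base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5)
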
@@ -505,16 +522,14 @@ def main(args): # pylint: disable=redefined-outer-name
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # void if keep_all_best False
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    # define dataloaders
    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
@@ -523,24 +538,32 @@ def main(args): # pylint: disable=redefined-outer-name
    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                 scheduler, ap, global_step,
                                                 epoch)
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                      global_step, epoch)
        train_avg_loss_dict, global_step = train(
            train_loader, model, criterion, optimizer, scheduler, ap, global_step, epoch
        )
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_loss']
        target_loss = train_avg_loss_dict["avg_loss"]
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer,
                                    global_step, epoch, c.r, OUT_PATH, model_characters,
                                    keep_all_best=keep_all_best, keep_after=keep_after)
            target_loss = eval_avg_loss_dict["avg_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
            model,
            optimizer,
            global_step,
            epoch,
            c.r,
            OUT_PATH,
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
        )


if __name__ == '__main__':
if __name__ == "__main__":
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_class='tts')
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")

    try:
        main(args)

@@ -22,14 +22,17 @@ from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.distribute import (DistributedSampler, apply_gradient_allreduce,
                                  init_distributed, reduce_tensor)
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.distribute import DistributedSampler, apply_gradient_allreduce, init_distributed, reduce_tensor
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
from TTS.utils.radam import RAdam
from TTS.utils.training import (NoamLR, adam_weight_decay, check_update,
                                gradual_training_scheduler, set_weight_decay,
                                setup_torch_training_env)
from TTS.utils.training import (
    NoamLR,
    adam_weight_decay,
    check_update,
    gradual_training_scheduler,
    set_weight_decay,
    setup_torch_training_env,
)

use_cuda, num_gpus = setup_torch_training_env(True, False)
@@ -42,13 +45,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
        dataset = MyDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=c.model.lower() == 'tacotron',
            compute_linear_spec=c.model.lower() == "tacotron",
            meta_data=meta_data_eval if is_val else meta_data_train,
            ap=ap,
            tp=c.characters if 'characters' in c.keys() else None,
            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
            batch_group_size=0 if is_val else c.batch_group_size *
            c.batch_size,
            tp=c.characters if "characters" in c.keys() else None,
            add_blank=c["add_blank"] if "add_blank" in c.keys() else False,
            batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
            min_seq_len=c.min_seq_len,
            max_seq_len=c.max_seq_len,
            phoneme_cache_path=c.phoneme_cache_path,
@@ -56,11 +58,10 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
            phoneme_language=c.phoneme_language,
            enable_eos_bos=c.enable_eos_bos_chars,
            verbose=verbose,
            speaker_mapping=(speaker_mapping if (
                c.use_speaker_embedding
                and c.use_external_speaker_embedding_file
            ) else None)
        )
            speaker_mapping=(
                speaker_mapping if (c.use_speaker_embedding and c.use_external_speaker_embedding_file) else None
            ),
        )

        if c.use_phonemes and c.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
@@ -75,11 +76,12 @@ def setup_loader(ap, r, is_val=False, verbose=False, dataset=None):
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers
            if is_val else c.num_loader_workers,
            pin_memory=False)
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=False,
        )
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
@@ -97,21 +99,16 @@ def format_data(data):
            speaker_embeddings = data[8]
            speaker_ids = None
        else:
            speaker_ids = [
                speaker_mapping[speaker_name] for speaker_name in speaker_names
            ]
            speaker_ids = [speaker_mapping[speaker_name] for speaker_name in speaker_names]
            speaker_ids = torch.LongTensor(speaker_ids)
            speaker_embeddings = None
    else:
        speaker_embeddings = None
        speaker_ids = None

    # set stop targets view, we predict a single stop token per iteration.
    stop_targets = stop_targets.view(text_input.shape[0],
                                     stop_targets.size(1) // c.r, -1)
    stop_targets = (stop_targets.sum(2) >
                    0.0).unsqueeze(2).float().squeeze(2)
    stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // c.r, -1)
    stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)

    # dispatch data to GPU
    if use_cuda:
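
Tacotron predicts r frames per decoder step, so the per-frame stop flags are folded into one flag per step: reshape to [batch, steps, r] and mark a step as a stop if any of its r frames is. A toy version of the same reduction (the unsqueeze/squeeze pair in the original is a no-op kept from older code):

import torch

batch, frames, r = 2, 12, 3            # r is the reduction factor (c.r)
stop_targets = torch.zeros(batch, frames)
stop_targets[:, -2:] = 1.0             # the last two frames are stop frames

per_step = stop_targets.view(batch, frames // r, -1)
per_step = (per_step.sum(2) > 0.0).float()
print(per_step)                        # tensor([[0., 0., 0., 1.], [0., 0., 0., 1.]])
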
@@ -126,17 +123,26 @@ def format_data(data):
        if speaker_embeddings is not None:
            speaker_embeddings = speaker_embeddings.cuda(non_blocking=True)

    return text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length
    return (
        text_input,
        text_lengths,
        mel_input,
        mel_lengths,
        linear_input,
        stop_targets,
        speaker_ids,
        speaker_embeddings,
        max_text_length,
        max_spec_length,
    )


def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
          ap, global_step, epoch, scaler, scaler_st):
def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler, ap, global_step, epoch, scaler, scaler_st):
    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
@@ -145,7 +151,18 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, max_text_length, max_spec_length = format_data(data)
        (
            text_input,
            text_lengths,
            mel_input,
            mel_lengths,
            linear_input,
            stop_targets,
            speaker_ids,
            speaker_embeddings,
            max_text_length,
            max_spec_length,
        ) = format_data(data)
        loader_time = time.time() - end_time

        global_step += 1
@@ -161,35 +178,65 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            # forward pass model
            if c.bidirectional_decoder or c.double_decoder_consistency:
                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input, text_lengths, mel_input, mel_lengths, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                    text_input,
                    text_lengths,
                    mel_input,
                    mel_lengths,
                    speaker_ids=speaker_ids,
                    speaker_embeddings=speaker_embeddings,
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the [alignment] lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
                alignment_lengths = (
                    mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
                ) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                  linear_input, stop_tokens, stop_targets,
                                  mel_lengths, decoder_backward_output,
                                  alignments, alignment_lengths,
                                  alignments_backward, text_lengths)
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

        # check nan loss
        if torch.isnan(loss_dict['loss']).any():
            raise RuntimeError(f'Detected NaN loss at step {global_step}.')
        if torch.isnan(loss_dict["loss"]).any():
            raise RuntimeError(f"Detected NaN loss at step {global_step}.")

        # optimizer step
        if c.mixed_precision:
            # model optimizer step in mixed precision mode
            scaler.scale(loss_dict['loss']).backward()
            scaler.scale(loss_dict["loss"]).backward()
            scaler.unscale_(optimizer)
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
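
The guided-attention branch rounds every mel length up to the next multiple of the reduction factor r before dividing, so the alignment lengths cover the padded decoder steps. In numbers:

import torch

r = 5                                  # reduction factor (model.decoder.r)
mel_lengths = torch.tensor([98, 100, 103])

pad = r - (mel_lengths.max() % r)      # 103 % 5 == 3, so pad == 2
alignment_lengths = (mel_lengths + pad) // r
print(alignment_lengths)               # tensor([20, 20, 21])
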
@@ -198,7 +245,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,

            # stopnet optimizer step
            if c.separate_stopnet:
                scaler_st.scale(loss_dict['stopnet_loss']).backward()
                scaler_st.scale(loss_dict["stopnet_loss"]).backward()
                scaler.unscale_(optimizer_st)
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
@@ -208,14 +255,14 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                grad_norm_st = 0
        else:
            # main model optimizer step
            loss_dict['loss'].backward()
            loss_dict["loss"].backward()
            optimizer, current_lr = adam_weight_decay(optimizer)
            grad_norm, _ = check_update(model, c.grad_clip, ignore_stopnet=True)
            optimizer.step()

            # stopnet optimizer step
            if c.separate_stopnet:
                loss_dict['stopnet_loss'].backward()
                loss_dict["stopnet_loss"].backward()
                optimizer_st, _ = adam_weight_decay(optimizer_st)
                grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
                optimizer_st.step()
@@ -224,17 +271,19 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,

        # compute alignment error (the lower the better )
        align_error = 1 - alignment_diagonal_score(alignments)
        loss_dict['align_error'] = align_error
        loss_dict["align_error"] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
            loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'] .data, num_gpus)
            loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus) if c.stopnet else loss_dict['stopnet_loss']
            loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
            loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
            loss_dict["loss"] = reduce_tensor(loss_dict["loss"].data, num_gpus)
            loss_dict["stopnet_loss"] = (
                reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus) if c.stopnet else loss_dict["stopnet_loss"]
            )

        # detach loss values
        loss_dict_new = dict()
@@ -248,9 +297,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
            update_train_values["avg_" + key] = value
        update_train_values["avg_loader_time"] = loader_time
        update_train_values["avg_step_time"] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
@@ -262,8 +311,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)
            c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
@@ -273,7 +321,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "grad_norm_st": grad_norm_st,
                    "step_time": step_time
                    "step_time": step_time,
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)
@@ -281,17 +329,26 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH,
                                    optimizer_st=optimizer_st,
                                    model_loss=loss_dict['postnet_loss'],
                                    characters=model_characters,
                                    scaler=scaler.state_dict() if c.mixed_precision else None)
                    save_checkpoint(
                        model,
                        optimizer,
                        global_step,
                        epoch,
                        model.decoder.r,
                        OUT_PATH,
                        optimizer_st=optimizer_st,
                        model_loss=loss_dict["postnet_loss"],
                        characters=model_characters,
                        scaler=scaler.state_dict() if c.mixed_precision else None,
                    )

                # Diagnostic visualizations
                const_spec = postnet_output[0].data.cpu().numpy()
                gt_spec = linear_input[0].data.cpu().numpy() if c.model in [
                    "Tacotron", "TacotronGST"
                ] else mel_input[0].data.cpu().numpy()
                gt_spec = (
                    linear_input[0].data.cpu().numpy()
                    if c.model in ["Tacotron", "TacotronGST"]
                    else mel_input[0].data.cpu().numpy()
                )
                align_img = alignments[0].data.cpu().numpy()

                figures = {
@@ -301,7 +358,9 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                }

                if c.bidirectional_decoder or c.double_decoder_consistency:
                    figures["alignment_backward"] = plot_alignment(alignments_backward[0].data.cpu().numpy(), output_fig=False)
                    figures["alignment_backward"] = plot_alignment(
                        alignments_backward[0].data.cpu().numpy(), output_fig=False
                    )

                tb_logger.tb_train_figures(global_step, figures)
@@ -310,9 +369,7 @@ def train(data_loader, model, criterion, optimizer, optimizer_st, scheduler,
                    train_audio = ap.inv_spectrogram(const_spec.T)
                else:
                    train_audio = ap.inv_melspectrogram(const_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
                tb_logger.tb_train_audios(global_step, {"TrainAudio": train_audio}, c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
@@ -339,31 +396,62 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, speaker_embeddings, _, _ = format_data(data)
            (
                text_input,
                text_lengths,
                mel_input,
                mel_lengths,
                linear_input,
                stop_targets,
                speaker_ids,
                speaker_embeddings,
                _,
                _,
            ) = format_data(data)
            assert mel_input.shape[1] % model.decoder.r == 0

            # forward pass model
            if c.bidirectional_decoder or c.double_decoder_consistency:
                decoder_output, postnet_output, alignments, stop_tokens, decoder_backward_output, alignments_backward = model(
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                (
                    decoder_output,
                    postnet_output,
                    alignments,
                    stop_tokens,
                    decoder_backward_output,
                    alignments_backward,
                ) = model(
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
                )
            else:
                decoder_output, postnet_output, alignments, stop_tokens = model(
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings)
                    text_input, text_lengths, mel_input, speaker_ids=speaker_ids, speaker_embeddings=speaker_embeddings
                )
                decoder_backward_output = None
                alignments_backward = None

            # set the alignment lengths wrt reduction factor for guided attention
            if mel_lengths.max() % model.decoder.r != 0:
                alignment_lengths = (mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))) // model.decoder.r
                alignment_lengths = (
                    mel_lengths + (model.decoder.r - (mel_lengths.max() % model.decoder.r))
                ) // model.decoder.r
            else:
                alignment_lengths = mel_lengths // model.decoder.r
                alignment_lengths = mel_lengths // model.decoder.r

            # compute loss
            loss_dict = criterion(postnet_output, decoder_output, mel_input,
                                  linear_input, stop_tokens, stop_targets,
                                  mel_lengths, decoder_backward_output,
                                  alignments, alignment_lengths, alignments_backward,
                                  text_lengths)
            loss_dict = criterion(
                postnet_output,
                decoder_output,
                mel_input,
                linear_input,
                stop_tokens,
                stop_targets,
                mel_lengths,
                decoder_backward_output,
                alignments,
                alignment_lengths,
                alignments_backward,
                text_lengths,
            )

            # step time
            step_time = time.time() - start_time
@@ -371,14 +459,14 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error
            loss_dict["align_error"] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['postnet_loss'] = reduce_tensor(loss_dict['postnet_loss'].data, num_gpus)
                loss_dict['decoder_loss'] = reduce_tensor(loss_dict['decoder_loss'].data, num_gpus)
                loss_dict["postnet_loss"] = reduce_tensor(loss_dict["postnet_loss"].data, num_gpus)
                loss_dict["decoder_loss"] = reduce_tensor(loss_dict["decoder_loss"].data, num_gpus)
                if c.stopnet:
                    loss_dict['stopnet_loss'] = reduce_tensor(loss_dict['stopnet_loss'].data, num_gpus)
                    loss_dict["stopnet_loss"] = reduce_tensor(loss_dict["stopnet_loss"].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
@@ -392,7 +480,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values['avg_' + key] = value
                update_train_values["avg_" + key] = value
            keep_avg.update_values(update_train_values)

        if c.print_eval:
@@ -402,15 +490,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
        # Diagnostic visualizations
        idx = np.random.randint(mel_input.shape[0])
        const_spec = postnet_output[idx].data.cpu().numpy()
        gt_spec = linear_input[idx].data.cpu().numpy() if c.model in [
            "Tacotron", "TacotronGST"
        ] else mel_input[idx].data.cpu().numpy()
        gt_spec = (
            linear_input[idx].data.cpu().numpy()
            if c.model in ["Tacotron", "TacotronGST"]
            else mel_input[idx].data.cpu().numpy()
        )
        align_img = alignments[idx].data.cpu().numpy()

        eval_figures = {
            "prediction": plot_spectrogram(const_spec, ap, output_fig=False),
            "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
            "alignment": plot_alignment(align_img, output_fig=False)
            "alignment": plot_alignment(align_img, output_fig=False),
        }

        # Sample audio
@@ -418,14 +508,13 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
            eval_audio = ap.inv_spectrogram(const_spec.T)
        else:
            eval_audio = ap.inv_melspectrogram(const_spec.T)
        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                 c.audio["sample_rate"])
        tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio}, c.audio["sample_rate"])

        # Plot Validation Stats

        if c.bidirectional_decoder or c.double_decoder_consistency:
            align_b_img = alignments_backward[idx].data.cpu().numpy()
            eval_figures['alignment2'] = plot_alignment(align_b_img, output_fig=False)
            eval_figures["alignment2"] = plot_alignment(align_b_img, output_fig=False)
        tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
        tb_logger.tb_eval_figures(global_step, eval_figures)
@@ -436,7 +525,7 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
                "Prior to November 22, 1963.",
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
@@ -447,13 +536,17 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
        test_figures = {}
        print(" | > Synthesizing test sentences")
        speaker_id = 0 if c.use_speaker_embedding else None
        speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding'] if c.use_external_speaker_embedding_file and c.use_speaker_embedding else None
        speaker_embedding = (
            speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping) - 1)]]["embedding"]
            if c.use_external_speaker_embedding_file and c.use_speaker_embedding
            else None
        )
        style_wav = c.get("gst_style_input")
        if style_wav is None and c.use_gst:
            # initialize GST with a zero dict.
            style_wav = {}
            print("WARNING: You didn't provide a GST style wav, so a zero tensor is used instead!")
            for i in range(c.gst['gst_style_tokens']):
            for i in range(c.gst["gst_style_tokens"]):
                style_wav[str(i)] = 0
        style_wav = c.get("gst_style_input")
        for idx, test_sentence in enumerate(test_sentences):
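
When no style wav is configured, the script feeds the GST layer a dict mapping every style-token index to a zero weight, which amounts to a neutral style. A toy equivalent:

gst_style_tokens = 10                  # stands in for c.gst["gst_style_tokens"]
style_wav = {str(i): 0 for i in range(gst_style_tokens)}
# {"0": 0, "1": 0, ..., "9": 0}: every token weighted zero
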
@@ -468,25 +561,22 @@ def evaluate(data_loader, model, criterion, ap, global_step, epoch):
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars, #pylint: disable=unused-argument
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  # pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)
                    do_trim_silence=False,
                )

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                file_path = os.path.join(file_path, "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap, output_fig=False)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment, output_fig=False)
            except: #pylint: disable=bare-except
                test_audios["{}-audio".format(idx)] = wav
                test_figures["{}-prediction".format(idx)] = plot_spectrogram(postnet_output, ap, output_fig=False)
                test_figures["{}-alignment".format(idx)] = plot_alignment(alignment, output_fig=False)
            except:  # pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_audios(global_step, test_audios, c.audio["sample_rate"])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values
@@ -498,13 +588,12 @@ def main(args): # pylint: disable=redefined-outer-name
    ap = AudioProcessor(**c.audio)

    # setup custom characters if set in config file.
    if 'characters' in c.keys():
    if "characters" in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
        init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)
    model_characters = phonemes if c.use_phonemes else symbols

@@ -512,10 +601,10 @@ def main(args): # pylint: disable=redefined-outer-name
    meta_data_train, meta_data_eval = load_meta_data(c.datasets)

    # set the portion of the data used for training
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]
    if "train_portion" in c.keys():
        meta_data_train = meta_data_train[: int(len(meta_data_train) * c.train_portion)]
    if "eval_portion" in c.keys():
        meta_data_eval = meta_data_eval[: int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)
@@ -529,9 +618,7 @@ def main(args): # pylint: disable=redefined-outer-name
    params = set_weight_decay(model, c.wd)
    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
    if c.stopnet and c.separate_stopnet:
        optimizer_st = RAdam(model.decoder.stopnet.parameters(),
                             lr=c.lr,
                             weight_decay=0)
        optimizer_st = RAdam(model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
    else:
        optimizer_st = None
@@ -539,13 +626,13 @@ def main(args): # pylint: disable=redefined-outer-name
    criterion = TacotronLoss(c, stopnet_pos_weight=c.stopnet_pos_weight, ga_sigma=0.4)
    if args.restore_path:
        print(f" > Restoring from {os.path.basename(args.restore_path)}...")
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        checkpoint = torch.load(args.restore_path, map_location="cpu")
        try:
            print(" > Restoring Model...")
            model.load_state_dict(checkpoint['model'])
            model.load_state_dict(checkpoint["model"])
            # optimizer restore
            print(" > Restoring Optimizer...")
            optimizer.load_state_dict(checkpoint['optimizer'])
            optimizer.load_state_dict(checkpoint["optimizer"])
            if "scaler" in checkpoint and c.mixed_precision:
                print(" > Restoring AMP Scaler...")
                scaler.load_state_dict(checkpoint["scaler"])
@@ -554,17 +641,16 @@ def main(args): # pylint: disable=redefined-outer-name
        except (KeyError, RuntimeError):
            print(" > Partial model initialization...")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model_dict = set_init_dict(model_dict, checkpoint["model"], c)
            # torch.save(model_dict, os.path.join(OUT_PATH, 'state_dict.pt'))
            # print("State Dict saved for debug in: ", os.path.join(OUT_PATH, 'state_dict.pt'))
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group['lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
            group["lr"] = c.lr
        print(" > Model restored from step %d" % checkpoint["step"], flush=True)
        args.restore_step = checkpoint["step"]
    else:
        args.restore_step = 0
@@ -577,9 +663,7 @@ def main(args): # pylint: disable=redefined-outer-name
        model = apply_gradient_allreduce(model)

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
        scheduler = NoamLR(optimizer, warmup_steps=c.warmup_steps, last_epoch=args.restore_step - 1)
    else:
        scheduler = None
@@ -587,22 +671,17 @@ def main(args): # pylint: disable=redefined-outer-name
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if args.restore_step == 0 or not args.best_path:
        best_loss = float('inf')
        best_loss = float("inf")
        print(" > Starting with inf best loss.")
    else:
        print(" > Restoring best loss from "
              f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path,
                               map_location='cpu')['model_loss']
        print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
        best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
        print(f" > Starting with loaded last best loss {best_loss}.")
    keep_all_best = c.get('keep_all_best', False)
    keep_after = c.get('keep_after', 10000)  # void if keep_all_best False
    keep_all_best = c.get("keep_all_best", False)
    keep_after = c.get("keep_after", 10000)  # void if keep_all_best False

    # define data loaders
    train_loader = setup_loader(ap,
                                model.decoder.r,
                                is_val=False,
                                verbose=True)
    train_loader = setup_loader(ap, model.decoder.r, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, model.decoder.r, is_val=True)

    global_step = args.restore_step
@@ -617,28 +696,29 @@ def main(args): # pylint: disable=redefined-outer-name
                model.decoder_backward.set_r(r)
            train_loader.dataset.outputs_per_step = r
            eval_loader.dataset.outputs_per_step = r
            train_loader = setup_loader(ap,
                                        model.decoder.r,
                                        is_val=False,
                                        dataset=train_loader.dataset)
            eval_loader = setup_loader(ap,
                                       model.decoder.r,
                                       is_val=True,
                                       dataset=eval_loader.dataset)
            train_loader = setup_loader(ap, model.decoder.r, is_val=False, dataset=train_loader.dataset)
            eval_loader = setup_loader(ap, model.decoder.r, is_val=True, dataset=eval_loader.dataset)
        print("\n > Number of output frames:", model.decoder.r)
        # train one epoch
        train_avg_loss_dict, global_step = train(train_loader, model,
                                                 criterion, optimizer,
                                                 optimizer_st, scheduler, ap,
                                                 global_step, epoch, scaler,
                                                 scaler_st)
        train_avg_loss_dict, global_step = train(
            train_loader,
            model,
            criterion,
            optimizer,
            optimizer_st,
            scheduler,
            ap,
            global_step,
            epoch,
            scaler,
            scaler_st,
        )
        # eval one epoch
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap,
                                      global_step, epoch)
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_postnet_loss']
        target_loss = train_avg_loss_dict["avg_postnet_loss"]
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_postnet_loss']
            target_loss = eval_avg_loss_dict["avg_postnet_loss"]
        best_loss = save_best_model(
            target_loss,
            best_loss,
@@ -651,14 +731,13 @@ def main(args): # pylint: disable=redefined-outer-name
            model_characters,
            keep_all_best=keep_all_best,
            keep_after=keep_after,
            scaler=scaler.state_dict() if c.mixed_precision else None
            scaler=scaler.state_dict() if c.mixed_precision else None,
        )


if __name__ == '__main__':
if __name__ == "__main__":
    args = parse_arguments(sys.argv)
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
        args, model_class='tts')
    c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="tts")

    try:
        main(args)
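
The set_r and outputs_per_step juggling in the epoch loop comes from Tacotron's gradual training: the reduction factor r starts high (many frames per decoder step) and is lowered on a schedule as training progresses, so the loaders are rebuilt whenever r changes. A sketch of how such a schedule is usually resolved (the schedule values below are illustrative, not taken from this repo's config):

# each entry: [start_step, r, batch_size]; illustrative values only
gradual_training = [[0, 7, 64], [10000, 5, 64], [50000, 3, 32], [130000, 2, 32], [290000, 1, 32]]

def scheduled_r(global_step):
    r, batch_size = gradual_training[0][1], gradual_training[0][2]
    for start, new_r, new_bs in gradual_training:
        if global_step >= start:
            r, batch_size = new_r, new_bs
    return r, batch_size

print(scheduled_r(60000))  # (3, 32)
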
@@ -12,8 +12,7 @@ import torch
from torch.utils.data import DataLoader
from TTS.utils.arguments import parse_arguments, process_args
from TTS.utils.audio import AudioProcessor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict

from TTS.utils.radam import RAdam

@@ -21,8 +20,7 @@ from TTS.utils.training import setup_torch_training_env
from TTS.vocoder.datasets.gan_dataset import GANDataset
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
from TTS.vocoder.layers.losses import DiscriminatorLoss, GeneratorLoss
from TTS.vocoder.utils.generic_utils import (plot_results, setup_discriminator,
                                             setup_generator)
from TTS.vocoder.utils.generic_utils import plot_results, setup_discriminator, setup_generator
from TTS.vocoder.utils.io import save_best_model, save_checkpoint

# DISTRIBUTED
@@ -36,27 +34,30 @@ use_cuda, num_gpus = setup_torch_training_env(True, True)
def setup_loader(ap, is_val=False, verbose=False):
    loader = None
    if not is_val or c.run_eval:
        dataset = GANDataset(ap=ap,
                             items=eval_data if is_val else train_data,
                             seq_len=c.seq_len,
                             hop_len=ap.hop_length,
                             pad_short=c.pad_short,
                             conv_pad=c.conv_pad,
                             is_training=not is_val,
                             return_segments=not is_val,
                             use_noise_augment=c.use_noise_augment,
                             use_cache=c.use_cache,
                             verbose=verbose)
        dataset = GANDataset(
            ap=ap,
            items=eval_data if is_val else train_data,
            seq_len=c.seq_len,
            hop_len=ap.hop_length,
            pad_short=c.pad_short,
            conv_pad=c.conv_pad,
            is_training=not is_val,
            return_segments=not is_val,
            use_noise_augment=c.use_noise_augment,
            use_cache=c.use_cache,
            verbose=verbose,
        )
        dataset.shuffle_mapping()
        sampler = DistributedSampler(dataset, shuffle=True) if num_gpus > 1 else None
        loader = DataLoader(dataset,
                            batch_size=1 if is_val else c.batch_size,
                            shuffle=num_gpus == 0,
                            drop_last=False,
                            sampler=sampler,
                            num_workers=c.num_val_loader_workers
                            if is_val else c.num_loader_workers,
                            pin_memory=False)
        loader = DataLoader(
            dataset,
            batch_size=1 if is_val else c.batch_size,
            shuffle=num_gpus == 0,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
            pin_memory=False,
        )
    return loader
@@ -83,16 +84,26 @@ def format_data(data):
    return co, x, None, None


def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
          scheduler_G, scheduler_D, ap, global_step, epoch):
def train(
    model_G,
    criterion_G,
    optimizer_G,
    model_D,
    criterion_D,
    optimizer_D,
    scheduler_G,
    scheduler_D,
    ap,
    global_step,
    epoch,
):
    data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
    model_G.train()
    model_D.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
        batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
@@ -148,16 +159,14 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
            scores_fake = D_out_fake

        # compute losses
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
                                  feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict['G_loss']
        loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)
        loss_G = loss_G_dict["G_loss"]

        # optimizer generator
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(),
                                           c.gen_clip_grad)
            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
        optimizer_G.step()
        if scheduler_G is not None:
            scheduler_G.step()
@@ -202,14 +211,13 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,

        # compute losses
        loss_D_dict = criterion_D(scores_fake, scores_real)
        loss_D = loss_D_dict['D_loss']
        loss_D = loss_D_dict["D_loss"]

        # optimizer discriminator
        optimizer_D.zero_grad()
        loss_D.backward()
        if c.disc_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_D.parameters(),
                                           c.disc_clip_grad)
            torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
        optimizer_D.step()
        if scheduler_D is not None:
            scheduler_D.step()
@@ -224,36 +232,31 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
epoch_time += step_time
|
||||
|
||||
# get current learning rates
|
||||
current_lr_G = list(optimizer_G.param_groups)[0]['lr']
|
||||
current_lr_D = list(optimizer_D.param_groups)[0]['lr']
|
||||
current_lr_G = list(optimizer_G.param_groups)[0]["lr"]
|
||||
current_lr_D = list(optimizer_D.param_groups)[0]["lr"]
|
||||
|
||||
# update avg stats
|
||||
update_train_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_train_values['avg_' + key] = value
|
||||
update_train_values['avg_loader_time'] = loader_time
|
||||
update_train_values['avg_step_time'] = step_time
|
||||
update_train_values["avg_" + key] = value
|
||||
update_train_values["avg_loader_time"] = loader_time
|
||||
update_train_values["avg_step_time"] = step_time
|
||||
keep_avg.update_values(update_train_values)
|
||||
|
||||
# print training stats
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {
|
||||
'step_time': [step_time, 2],
|
||||
'loader_time': [loader_time, 4],
|
||||
"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr_G": current_lr_G,
|
||||
"current_lr_D": current_lr_D
|
||||
"current_lr_D": current_lr_D,
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||
log_dict, loss_dict, keep_avg.avg_values)
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
# plot step stats
|
||||
if global_step % 10 == 0:
|
||||
iter_stats = {
|
||||
"lr_G": current_lr_G,
|
||||
"lr_D": current_lr_D,
|
||||
"step_time": step_time
|
||||
}
|
||||
iter_stats = {"lr_G": current_lr_G, "lr_D": current_lr_D, "step_time": step_time}
|
||||
iter_stats.update(loss_dict)
|
||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||
|
||||
|
@@ -261,27 +264,26 @@ def train(model_G, criterion_G, optimizer_G, model_D, criterion_D, optimizer_D,
|
|||
if global_step % c.save_step == 0:
|
||||
if c.checkpoint:
|
||||
# save model
|
||||
save_checkpoint(model_G,
|
||||
optimizer_G,
|
||||
scheduler_G,
|
||||
model_D,
|
||||
optimizer_D,
|
||||
scheduler_D,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict)
|
||||
save_checkpoint(
|
||||
model_G,
|
||||
optimizer_G,
|
||||
scheduler_G,
|
||||
model_D,
|
||||
optimizer_D,
|
||||
scheduler_D,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat_vis, y_G, ap, global_step,
|
||||
'train')
|
||||
figures = plot_results(y_hat_vis, y_G, ap, global_step, "train")
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat_vis[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_train_audios(global_step,
|
||||
{'train/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
tb_logger.tb_train_audios(global_step, {"train/audio": sample_voice}, c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
# print epoch stats
|
||||
|
@@ -356,8 +358,7 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
feats_fake, feats_real = None, None
|
||||
|
||||
# compute losses
|
||||
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake,
|
||||
feats_real, y_hat_sub, y_G_sub)
|
||||
loss_G_dict = criterion_G(y_hat, y_G, scores_fake, feats_fake, feats_real, y_hat_sub, y_G_sub)
|
||||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_G_dict.items():
|
||||
|
@@ -413,9 +414,9 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
# update avg stats
|
||||
update_eval_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_eval_values['avg_' + key] = value
|
||||
update_eval_values['avg_loader_time'] = loader_time
|
||||
update_eval_values['avg_step_time'] = step_time
|
||||
update_eval_values["avg_" + key] = value
|
||||
update_eval_values["avg_loader_time"] = loader_time
|
||||
update_eval_values["avg_step_time"] = step_time
|
||||
keep_avg.update_values(update_eval_values)
|
||||
|
||||
# print eval stats
|
||||
|
@@ -424,13 +425,12 @@ def evaluate(model_G, criterion_G, model_D, criterion_D, ap, global_step, epoch)
|
|||
|
||||
if args.rank == 0:
|
||||
# compute spectrograms
|
||||
figures = plot_results(y_hat, y_G, ap, global_step, 'eval')
|
||||
figures = plot_results(y_hat, y_G, ap, global_step, "eval")
|
||||
tb_logger.tb_eval_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = y_hat[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])
|
||||
|
||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||
|
||||
|
@@ -446,8 +446,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(f" > Loading wavs from: {c.data_path}")
|
||||
if c.feature_path is not None:
|
||||
print(f" > Loading features from: {c.feature_path}")
|
||||
eval_data, train_data = load_wav_feat_data(
|
||||
c.data_path, c.feature_path, c.eval_split_size)
|
||||
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
|
||||
else:
|
||||
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
||||
|
||||
|
@@ -456,8 +455,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
# DISTRIBUTED
|
||||
if num_gpus > 1:
|
||||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
|
||||
|
||||
# setup models
|
||||
model_gen = setup_generator(c)
|
||||
|
@@ -465,21 +463,17 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
# setup optimizers
|
||||
optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0)
|
||||
optimizer_disc = RAdam(model_disc.parameters(),
|
||||
lr=c.lr_disc,
|
||||
weight_decay=0)
|
||||
optimizer_disc = RAdam(model_disc.parameters(), lr=c.lr_disc, weight_decay=0)
|
||||
|
||||
# schedulers
|
||||
scheduler_gen = None
|
||||
scheduler_disc = None
|
||||
if 'lr_scheduler_gen' in c:
|
||||
if "lr_scheduler_gen" in c:
|
||||
scheduler_gen = getattr(torch.optim.lr_scheduler, c.lr_scheduler_gen)
|
||||
scheduler_gen = scheduler_gen(
|
||||
optimizer_gen, **c.lr_scheduler_gen_params)
|
||||
if 'lr_scheduler_disc' in c:
|
||||
scheduler_gen = scheduler_gen(optimizer_gen, **c.lr_scheduler_gen_params)
|
||||
if "lr_scheduler_disc" in c:
|
||||
scheduler_disc = getattr(torch.optim.lr_scheduler, c.lr_scheduler_disc)
|
||||
scheduler_disc = scheduler_disc(
|
||||
optimizer_disc, **c.lr_scheduler_disc_params)
|
||||
scheduler_disc = scheduler_disc(optimizer_disc, **c.lr_scheduler_disc_params)
|
||||
|
||||
# setup criterion
|
||||
criterion_gen = GeneratorLoss(c)
|
||||
|
@@ -487,47 +481,46 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
if args.restore_path:
|
||||
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
|
||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||
checkpoint = torch.load(args.restore_path, map_location="cpu")
|
||||
try:
|
||||
print(" > Restoring Generator Model...")
|
||||
model_gen.load_state_dict(checkpoint['model'])
|
||||
model_gen.load_state_dict(checkpoint["model"])
|
||||
print(" > Restoring Generator Optimizer...")
|
||||
optimizer_gen.load_state_dict(checkpoint['optimizer'])
|
||||
optimizer_gen.load_state_dict(checkpoint["optimizer"])
|
||||
print(" > Restoring Discriminator Model...")
|
||||
model_disc.load_state_dict(checkpoint['model_disc'])
|
||||
model_disc.load_state_dict(checkpoint["model_disc"])
|
||||
print(" > Restoring Discriminator Optimizer...")
|
||||
optimizer_disc.load_state_dict(checkpoint['optimizer_disc'])
|
||||
if 'scheduler' in checkpoint and scheduler_gen is not None:
|
||||
optimizer_disc.load_state_dict(checkpoint["optimizer_disc"])
|
||||
if "scheduler" in checkpoint and scheduler_gen is not None:
|
||||
print(" > Restoring Generator LR Scheduler...")
|
||||
scheduler_gen.load_state_dict(checkpoint['scheduler'])
|
||||
scheduler_gen.load_state_dict(checkpoint["scheduler"])
|
||||
# NOTE: Not sure if necessary
|
||||
scheduler_gen.optimizer = optimizer_gen
|
||||
if 'scheduler_disc' in checkpoint and scheduler_disc is not None:
|
||||
if "scheduler_disc" in checkpoint and scheduler_disc is not None:
|
||||
print(" > Restoring Discriminator LR Scheduler...")
|
||||
scheduler_disc.load_state_dict(checkpoint['scheduler_disc'])
|
||||
scheduler_disc.load_state_dict(checkpoint["scheduler_disc"])
|
||||
scheduler_disc.optimizer = optimizer_disc
|
||||
except RuntimeError:
|
||||
# restore only matching layers.
|
||||
print(" > Partial model initialization...")
|
||||
model_dict = model_gen.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model_gen.load_state_dict(model_dict)
|
||||
|
||||
model_dict = model_disc.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model_disc'], c)
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model_disc"], c)
|
||||
model_disc.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# reset lr if not continuing training.
|
||||
for group in optimizer_gen.param_groups:
|
||||
group['lr'] = c.lr_gen
|
||||
group["lr"] = c.lr_gen
|
||||
|
||||
for group in optimizer_disc.param_groups:
|
||||
group['lr'] = c.lr_disc
|
||||
group["lr"] = c.lr_disc
|
||||
|
||||
print(f" > Model restored from step {checkpoint['step']:d}",
|
||||
flush=True)
|
||||
args.restore_step = checkpoint['step']
|
||||
print(f" > Model restored from step {checkpoint['step']:d}", flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
|
@@ -548,50 +541,55 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(" > Discriminator has {} parameters".format(num_params), flush=True)
|
||||
|
||||
if args.restore_step == 0 or not args.best_path:
|
||||
best_loss = float('inf')
|
||||
best_loss = float("inf")
|
||||
print(" > Starting with inf best loss.")
|
||||
else:
|
||||
print(" > Restoring best loss from "
|
||||
f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path,
|
||||
map_location='cpu')['model_loss']
|
||||
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with best loss of {best_loss}.")
|
||||
keep_all_best = c.get('keep_all_best', False)
|
||||
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
|
||||
keep_all_best = c.get("keep_all_best", False)
|
||||
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
|
||||
|
||||
global_step = args.restore_step
|
||||
for epoch in range(0, c.epochs):
|
||||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
_, global_step = train(model_gen, criterion_gen, optimizer_gen,
|
||||
model_disc, criterion_disc, optimizer_disc,
|
||||
scheduler_gen, scheduler_disc, ap, global_step,
|
||||
epoch)
|
||||
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc,
|
||||
criterion_disc, ap,
|
||||
global_step, epoch)
|
||||
_, global_step = train(
|
||||
model_gen,
|
||||
criterion_gen,
|
||||
optimizer_gen,
|
||||
model_disc,
|
||||
criterion_disc,
|
||||
optimizer_disc,
|
||||
scheduler_gen,
|
||||
scheduler_disc,
|
||||
ap,
|
||||
global_step,
|
||||
epoch,
|
||||
)
|
||||
eval_avg_loss_dict = evaluate(model_gen, criterion_gen, model_disc, criterion_disc, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict[c.target_loss]
|
||||
best_loss = save_best_model(target_loss,
|
||||
best_loss,
|
||||
model_gen,
|
||||
optimizer_gen,
|
||||
scheduler_gen,
|
||||
model_disc,
|
||||
optimizer_disc,
|
||||
scheduler_disc,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
keep_all_best=keep_all_best,
|
||||
keep_after=keep_after,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
)
|
||||
best_loss = save_best_model(
|
||||
target_loss,
|
||||
best_loss,
|
||||
model_gen,
|
||||
optimizer_gen,
|
||||
scheduler_gen,
|
||||
model_disc,
|
||||
optimizer_disc,
|
||||
scheduler_disc,
|
||||
global_step,
|
||||
epoch,
|
||||
OUT_PATH,
|
||||
keep_all_best=keep_all_best,
|
||||
keep_after=keep_after,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_class='vocoder')
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
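The GAN vocoder loop above alternates a generator and a discriminator update inside every training step. Below is a minimal sketch of that alternation, reusing the script's names (model_G, criterion_G, c.gen_clip_grad, ...) but with the feature-matching, sub-band, and scheduler details stripped out; the three-argument criterion calls are a simplification of the library's actual signatures, assumed here for illustration.

    import torch

    def gan_train_step(model_G, model_D, criterion_G, criterion_D,
                       optimizer_G, optimizer_D, c, x, y):
        # 1) generator update: make the discriminator score y_hat as real
        y_hat = model_G(x)
        scores_fake = model_D(y_hat)
        loss_G = criterion_G(y_hat, y, scores_fake)["G_loss"]
        optimizer_G.zero_grad()
        loss_G.backward()
        if c.gen_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_G.parameters(), c.gen_clip_grad)
        optimizer_G.step()

        # 2) discriminator update: detach y_hat so no gradients flow into G
        scores_fake = model_D(y_hat.detach())
        scores_real = model_D(y)
        loss_D = criterion_D(scores_fake, scores_real)["D_loss"]
        optimizer_D.zero_grad()
        loss_D.backward()
        if c.disc_clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model_D.parameters(), c.disc_clip_grad)
        optimizer_D.step()
        return loss_G.item(), loss_D.item()

The detach before the discriminator pass is the key design point: both networks see the same generated audio, but only the generator step backpropagates through it.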
@@ -8,6 +8,7 @@ import traceback
|
|||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
# DISTRIBUTED
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP_th
|
||||
from torch.optim import Adam
|
||||
|
@@ -16,8 +17,7 @@ from torch.utils.data.distributed import DistributedSampler
|
|||
from TTS.utils.arguments import parse_arguments, process_args
|
||||
from TTS.utils.audio import AudioProcessor
|
||||
from TTS.utils.distribute import init_distributed
|
||||
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
|
||||
remove_experiment_folder, set_init_dict)
|
||||
from TTS.utils.generic_utils import KeepAverage, count_parameters, remove_experiment_folder, set_init_dict
|
||||
from TTS.utils.training import setup_torch_training_env
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
||||
|
@@ -31,27 +31,29 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
if is_val and not c.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
dataset = WaveGradDataset(ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=c.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=c.pad_short,
|
||||
conv_pad=c.conv_pad,
|
||||
is_training=not is_val,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=c.use_cache,
|
||||
verbose=verbose)
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=c.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=c.pad_short,
|
||||
conv_pad=c.conv_pad,
|
||||
is_training=not is_val,
|
||||
return_segments=True,
|
||||
use_noise_augment=False,
|
||||
use_cache=c.use_cache,
|
||||
verbose=verbose,
|
||||
)
|
||||
sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
batch_size=c.batch_size,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=c.num_val_loader_workers
|
||||
if is_val else c.num_loader_workers,
|
||||
pin_memory=False)
|
||||
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=c.batch_size,
|
||||
shuffle=num_gpus <= 1,
|
||||
drop_last=False,
|
||||
sampler=sampler,
|
||||
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
return loader
|
||||
|
||||
|
@@ -77,24 +79,21 @@ def format_test_data(data):
|
|||
return m, x
|
||||
|
||||
|
||||
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
||||
epoch):
|
||||
def train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch):
|
||||
data_loader = setup_loader(ap, is_val=False, verbose=(epoch == 0))
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
if use_cuda:
|
||||
batch_n_iter = int(
|
||||
len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
c_logger.print_train_start()
|
||||
# setup noise schedule
|
||||
noise_schedule = c['train_noise_schedule']
|
||||
betas = np.linspace(noise_schedule['min_val'], noise_schedule['max_val'],
|
||||
noise_schedule['num_steps'])
|
||||
if hasattr(model, 'module'):
|
||||
noise_schedule = c["train_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
if hasattr(model, "module"):
|
||||
model.module.compute_noise_level(betas)
|
||||
else:
|
||||
model.compute_noise_level(betas)
|
||||
|
@@ -109,7 +108,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
|
||||
with torch.cuda.amp.autocast(enabled=c.mixed_precision):
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
if hasattr(model, "module"):
|
||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||
else:
|
||||
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||
|
@@ -119,11 +118,11 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
loss_wavegrad_dict = {'wavegrad_loss': loss}
|
||||
loss_wavegrad_dict = {"wavegrad_loss": loss}
|
||||
|
||||
# check nan loss
|
||||
if torch.isnan(loss).any():
|
||||
raise RuntimeError(f'Detected NaN loss at step {global_step}.')
|
||||
raise RuntimeError(f"Detected NaN loss at step {global_step}.")
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
|
@@ -131,14 +130,12 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
if c.mixed_precision:
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||
c.clip_grad)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
|
||||
c.clip_grad)
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), c.clip_grad)
|
||||
optimizer.step()
|
||||
|
||||
# schedule update
|
||||
|
@@ -158,35 +155,30 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
epoch_time += step_time
|
||||
|
||||
# get current learning rates
|
||||
current_lr = list(optimizer.param_groups)[0]['lr']
|
||||
current_lr = list(optimizer.param_groups)[0]["lr"]
|
||||
|
||||
# update avg stats
|
||||
update_train_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_train_values['avg_' + key] = value
|
||||
update_train_values['avg_loader_time'] = loader_time
|
||||
update_train_values['avg_step_time'] = step_time
|
||||
update_train_values["avg_" + key] = value
|
||||
update_train_values["avg_loader_time"] = loader_time
|
||||
update_train_values["avg_step_time"] = step_time
|
||||
keep_avg.update_values(update_train_values)
|
||||
|
||||
# print training stats
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {
|
||||
'step_time': [step_time, 2],
|
||||
'loader_time': [loader_time, 4],
|
||||
"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr": current_lr,
|
||||
"grad_norm": grad_norm.item()
|
||||
"grad_norm": grad_norm.item(),
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step,
|
||||
log_dict, loss_dict, keep_avg.avg_values)
|
||||
c_logger.print_train_step(batch_n_iter, num_iter, global_step, log_dict, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if args.rank == 0:
|
||||
# plot step stats
|
||||
if global_step % 10 == 0:
|
||||
iter_stats = {
|
||||
"lr": current_lr,
|
||||
"grad_norm": grad_norm.item(),
|
||||
"step_time": step_time
|
||||
}
|
||||
iter_stats = {"lr": current_lr, "grad_norm": grad_norm.item(), "step_time": step_time}
|
||||
iter_stats.update(loss_dict)
|
||||
tb_logger.tb_train_iter_stats(global_step, iter_stats)
|
||||
|
||||
|
@@ -205,7 +197,7 @@ def train(model, criterion, optimizer, scheduler, scaler, ap, global_step,
|
|||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None,
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
|
@@ -242,19 +234,17 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
global_step += 1
|
||||
|
||||
# compute noisy input
|
||||
if hasattr(model, 'module'):
|
||||
if hasattr(model, "module"):
|
||||
noise, x_noisy, noise_scale = model.module.compute_y_n(x)
|
||||
else:
|
||||
noise, x_noisy, noise_scale = model.compute_y_n(x)
|
||||
|
||||
|
||||
# forward pass
|
||||
noise_hat = model(x_noisy, m, noise_scale)
|
||||
|
||||
# compute losses
|
||||
loss = criterion(noise, noise_hat)
|
||||
loss_wavegrad_dict = {'wavegrad_loss': loss}
|
||||
|
||||
loss_wavegrad_dict = {"wavegrad_loss": loss}
|
||||
|
||||
loss_dict = dict()
|
||||
for key, value in loss_wavegrad_dict.items():
|
||||
|
@@ -269,9 +259,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
# update avg stats
|
||||
update_eval_values = dict()
|
||||
for key, value in loss_dict.items():
|
||||
update_eval_values['avg_' + key] = value
|
||||
update_eval_values['avg_loader_time'] = loader_time
|
||||
update_eval_values['avg_step_time'] = step_time
|
||||
update_eval_values["avg_" + key] = value
|
||||
update_eval_values["avg_loader_time"] = loader_time
|
||||
update_eval_values["avg_step_time"] = step_time
|
||||
keep_avg.update_values(update_eval_values)
|
||||
|
||||
# print eval stats
|
||||
|
@@ -284,11 +274,9 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
m, x = format_test_data(samples[0])
|
||||
|
||||
# setup noise schedule and inference
|
||||
noise_schedule = c['test_noise_schedule']
|
||||
betas = np.linspace(noise_schedule['min_val'],
|
||||
noise_schedule['max_val'],
|
||||
noise_schedule['num_steps'])
|
||||
if hasattr(model, 'module'):
|
||||
noise_schedule = c["test_noise_schedule"]
|
||||
betas = np.linspace(noise_schedule["min_val"], noise_schedule["max_val"], noise_schedule["num_steps"])
|
||||
if hasattr(model, "module"):
|
||||
model.module.compute_noise_level(betas)
|
||||
# compute voice
|
||||
x_pred = model.module.inference(m)
|
||||
|
@@ -298,13 +286,12 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
x_pred = model.inference(m)
|
||||
|
||||
# compute spectrograms
|
||||
figures = plot_results(x_pred, x, ap, global_step, 'eval')
|
||||
figures = plot_results(x_pred, x, ap, global_step, "eval")
|
||||
tb_logger.tb_eval_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
sample_voice = x_pred[0].squeeze(0).detach().cpu().numpy()
|
||||
tb_logger.tb_eval_audios(global_step, {'eval/audio': sample_voice},
|
||||
c.audio["sample_rate"])
|
||||
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_voice}, c.audio["sample_rate"])
|
||||
|
||||
tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
|
||||
data_loader.dataset.return_segments = True
|
||||
|
@@ -318,8 +305,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(f" > Loading wavs from: {c.data_path}")
|
||||
if c.feature_path is not None:
|
||||
print(f" > Loading features from: {c.feature_path}")
|
||||
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path,
|
||||
c.eval_split_size)
|
||||
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
|
||||
else:
|
||||
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
||||
|
||||
|
@@ -328,8 +314,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
# DISTRIBUTED
|
||||
if num_gpus > 1:
|
||||
init_distributed(args.rank, num_gpus, args.group_id,
|
||||
c.distributed["backend"], c.distributed["url"])
|
||||
init_distributed(args.rank, num_gpus, args.group_id, c.distributed["backend"], c.distributed["url"])
|
||||
|
||||
# setup models
|
||||
model = setup_generator(c)
|
||||
|
@@ -342,7 +327,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
# schedulers
|
||||
scheduler = None
|
||||
if 'lr_scheduler' in c:
|
||||
if "lr_scheduler" in c:
|
||||
scheduler = getattr(torch.optim.lr_scheduler, c.lr_scheduler)
|
||||
scheduler = scheduler(optimizer, **c.lr_scheduler_params)
|
||||
|
||||
|
@@ -355,15 +340,15 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
if args.restore_path:
|
||||
print(f" > Restoring from {os.path.basename(args.restore_path)}...")
|
||||
checkpoint = torch.load(args.restore_path, map_location='cpu')
|
||||
checkpoint = torch.load(args.restore_path, map_location="cpu")
|
||||
try:
|
||||
print(" > Restoring Model...")
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
model.load_state_dict(checkpoint["model"])
|
||||
print(" > Restoring Optimizer...")
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
if 'scheduler' in checkpoint:
|
||||
optimizer.load_state_dict(checkpoint["optimizer"])
|
||||
if "scheduler" in checkpoint:
|
||||
print(" > Restoring LR Scheduler...")
|
||||
scheduler.load_state_dict(checkpoint['scheduler'])
|
||||
scheduler.load_state_dict(checkpoint["scheduler"])
|
||||
# NOTE: Not sure if necessary
|
||||
scheduler.optimizer = optimizer
|
||||
if "scaler" in checkpoint and c.mixed_precision:
|
||||
|
@@ -373,17 +358,16 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
# restore only matching layers.
|
||||
print(" > Partial model initialization...")
|
||||
model_dict = model.state_dict()
|
||||
model_dict = set_init_dict(model_dict, checkpoint['model'], c)
|
||||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model.load_state_dict(model_dict)
|
||||
del model_dict
|
||||
|
||||
# reset lr if not continuing training.
|
||||
for group in optimizer.param_groups:
|
||||
group['lr'] = c.lr
|
||||
group["lr"] = c.lr
|
||||
|
||||
print(" > Model restored from step %d" % checkpoint['step'],
|
||||
flush=True)
|
||||
args.restore_step = checkpoint['step']
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
||||
|
@@ -395,22 +379,19 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(" > WaveGrad has {} parameters".format(num_params), flush=True)
|
||||
|
||||
if args.restore_step == 0 or not args.best_path:
|
||||
best_loss = float('inf')
|
||||
best_loss = float("inf")
|
||||
print(" > Starting with inf best loss.")
|
||||
else:
|
||||
print(" > Restoring best loss from "
|
||||
f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path,
|
||||
map_location='cpu')['model_loss']
|
||||
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with loaded last best loss {best_loss}.")
|
||||
keep_all_best = c.get('keep_all_best', False)
|
||||
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
|
||||
keep_all_best = c.get("keep_all_best", False)
|
||||
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
|
||||
|
||||
global_step = args.restore_step
|
||||
for epoch in range(0, c.epochs):
|
||||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
_, global_step = train(model, criterion, optimizer, scheduler, scaler,
|
||||
ap, global_step, epoch)
|
||||
_, global_step = train(model, criterion, optimizer, scheduler, scaler, ap, global_step, epoch)
|
||||
eval_avg_loss_dict = evaluate(model, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict[c.target_loss]
|
||||
|
@@ -429,14 +410,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
keep_all_best=keep_all_best,
|
||||
keep_after=keep_after,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_class='vocoder')
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
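The WaveGrad loop above leans entirely on two model methods: compute_noise_level(betas) precomputes the diffusion schedule from the train_noise_schedule config, and compute_y_n(x) corrupts a clean waveform so the network can learn to predict the injected noise. The sketch below shows the standard DDPM-style quantities behind both, for orientation only; the schedule values are example config numbers, and the library's actual compute_y_n may sample the noise level continuously between adjacent steps rather than on the discrete grid used here.

    import numpy as np
    import torch

    schedule = {"min_val": 1e-6, "max_val": 1e-2, "num_steps": 1000}  # like c["train_noise_schedule"]
    betas = np.linspace(schedule["min_val"], schedule["max_val"], schedule["num_steps"])
    alphas = 1.0 - betas
    alpha_hat = np.cumprod(alphas)  # cumulative signal level per diffusion step

    def compute_y_n(x):
        """Return (noise, x_noisy, noise_scale) for a randomly drawn step."""
        t = np.random.randint(len(alpha_hat))
        noise_scale = torch.tensor(alpha_hat[t] ** 0.5, dtype=x.dtype)
        noise = torch.randn_like(x)
        # x_noisy = sqrt(alpha_hat) * x + sqrt(1 - alpha_hat) * noise
        x_noisy = noise_scale * x + (1.0 - noise_scale ** 2) ** 0.5 * noise
        return noise, x_noisy, noise_scale.view(1)

The model then predicts noise_hat from (x_noisy, mel, noise_scale), which is why criterion(noise, noise_hat) above is the whole training objective: a plain regression between true and predicted noise.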
@@ -24,10 +24,7 @@ from TTS.utils.generic_utils import (
|
|||
set_init_dict,
|
||||
)
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.datasets.preprocess import (
|
||||
load_wav_data,
|
||||
load_wav_feat_data
|
||||
)
|
||||
from TTS.vocoder.datasets.preprocess import load_wav_data, load_wav_feat_data
|
||||
from TTS.vocoder.utils.distribution import discretized_mix_logistic_loss, gaussian_loss
|
||||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
from TTS.vocoder.utils.io import save_best_model, save_checkpoint
|
||||
|
@@ -40,26 +37,26 @@ def setup_loader(ap, is_val=False, verbose=False):
|
|||
if is_val and not c.run_eval:
|
||||
loader = None
|
||||
else:
|
||||
dataset = WaveRNNDataset(ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=c.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=c.padding,
|
||||
mode=c.mode,
|
||||
mulaw=c.mulaw,
|
||||
is_training=not is_val,
|
||||
verbose=verbose,
|
||||
)
|
||||
dataset = WaveRNNDataset(
|
||||
ap=ap,
|
||||
items=eval_data if is_val else train_data,
|
||||
seq_len=c.seq_len,
|
||||
hop_len=ap.hop_length,
|
||||
pad=c.padding,
|
||||
mode=c.mode,
|
||||
mulaw=c.mulaw,
|
||||
is_training=not is_val,
|
||||
verbose=verbose,
|
||||
)
|
||||
# sampler = DistributedSampler(dataset) if num_gpus > 1 else None
|
||||
loader = DataLoader(dataset,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate,
|
||||
batch_size=c.batch_size,
|
||||
num_workers=c.num_val_loader_workers
|
||||
if is_val
|
||||
else c.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate,
|
||||
batch_size=c.batch_size,
|
||||
num_workers=c.num_val_loader_workers if is_val else c.num_loader_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
return loader
|
||||
|
||||
|
||||
|
@@ -85,8 +82,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
epoch_time = 0
|
||||
keep_avg = KeepAverage()
|
||||
if use_cuda:
|
||||
batch_n_iter = int(len(data_loader.dataset) /
|
||||
(c.batch_size * num_gpus))
|
||||
batch_n_iter = int(len(data_loader.dataset) / (c.batch_size * num_gpus))
|
||||
else:
|
||||
batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
|
||||
end_time = time.time()
|
||||
|
@@ -114,8 +110,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(optimizer)
|
||||
if c.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), c.grad_clip)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
else:
|
||||
|
@@ -132,8 +127,7 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
raise RuntimeError(" [!] None loss. Exiting ...")
|
||||
loss.backward()
|
||||
if c.grad_clip > 0:
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
model.parameters(), c.grad_clip)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), c.grad_clip)
|
||||
optimizer.step()
|
||||
|
||||
if scheduler is not None:
|
||||
|
@@ -156,17 +150,19 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
|
||||
# print training stats
|
||||
if global_step % c.print_step == 0:
|
||||
log_dict = {"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr": cur_lr,
|
||||
}
|
||||
c_logger.print_train_step(batch_n_iter,
|
||||
num_iter,
|
||||
global_step,
|
||||
log_dict,
|
||||
loss_dict,
|
||||
keep_avg.avg_values,
|
||||
)
|
||||
log_dict = {
|
||||
"step_time": [step_time, 2],
|
||||
"loader_time": [loader_time, 4],
|
||||
"current_lr": cur_lr,
|
||||
}
|
||||
c_logger.print_train_step(
|
||||
batch_n_iter,
|
||||
num_iter,
|
||||
global_step,
|
||||
log_dict,
|
||||
loss_dict,
|
||||
keep_avg.avg_values,
|
||||
)
|
||||
|
||||
# plot step stats
|
||||
if global_step % 10 == 0:
|
||||
|
@@ -189,36 +185,36 @@ def train(model, optimizer, criterion, scheduler, scaler, ap, global_step, epoch
|
|||
epoch,
|
||||
OUT_PATH,
|
||||
model_losses=loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None,
|
||||
)
|
||||
|
||||
# synthesize a full voice
|
||||
rand_idx = random.randrange(0, len(train_data))
|
||||
wav_path = train_data[rand_idx] if not isinstance(
|
||||
train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
|
||||
wav_path = (
|
||||
train_data[rand_idx] if not isinstance(train_data[rand_idx], (tuple, list)) else train_data[rand_idx][0]
|
||||
)
|
||||
wav = ap.load_wav(wav_path)
|
||||
ground_mel = ap.melspectrogram(wav)
|
||||
ground_mel = torch.FloatTensor(ground_mel)
|
||||
if use_cuda:
|
||||
ground_mel = ground_mel.cuda(non_blocking=True)
|
||||
sample_wav = model.inference(ground_mel,
|
||||
c.batched,
|
||||
c.target_samples,
|
||||
c.overlap_samples,
|
||||
)
|
||||
sample_wav = model.inference(
|
||||
ground_mel,
|
||||
c.batched,
|
||||
c.target_samples,
|
||||
c.overlap_samples,
|
||||
)
|
||||
predict_mel = ap.melspectrogram(sample_wav)
|
||||
|
||||
# compute spectrograms
|
||||
figures = {"train/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"train/prediction": plot_spectrogram(predict_mel.T)
|
||||
}
|
||||
figures = {
|
||||
"train/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"train/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
tb_logger.tb_train_figures(global_step, figures)
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_train_audios(
|
||||
global_step, {
|
||||
"train/audio": sample_wav}, c.audio["sample_rate"]
|
||||
)
|
||||
tb_logger.tb_train_audios(global_step, {"train/audio": sample_wav}, c.audio["sample_rate"])
|
||||
end_time = time.time()
|
||||
|
||||
# print epoch stats
|
||||
|
@@ -277,36 +273,32 @@ def evaluate(model, criterion, ap, global_step, epoch):
|
|||
|
||||
# print eval stats
|
||||
if c.print_eval:
|
||||
c_logger.print_eval_step(
|
||||
num_iter, loss_dict, keep_avg.avg_values)
|
||||
c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)
|
||||
|
||||
if epoch % c.test_every_epochs == 0 and epoch != 0:
|
||||
# synthesize a full voice
|
||||
rand_idx = random.randrange(0, len(eval_data))
|
||||
wav_path = eval_data[rand_idx] if not isinstance(
|
||||
eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
|
||||
wav_path = eval_data[rand_idx] if not isinstance(eval_data[rand_idx], (tuple, list)) else eval_data[rand_idx][0]
|
||||
wav = ap.load_wav(wav_path)
|
||||
ground_mel = ap.melspectrogram(wav)
|
||||
ground_mel = torch.FloatTensor(ground_mel)
|
||||
if use_cuda:
|
||||
ground_mel = ground_mel.cuda(non_blocking=True)
|
||||
sample_wav = model.inference(ground_mel,
|
||||
c.batched,
|
||||
c.target_samples,
|
||||
c.overlap_samples,
|
||||
)
|
||||
sample_wav = model.inference(
|
||||
ground_mel,
|
||||
c.batched,
|
||||
c.target_samples,
|
||||
c.overlap_samples,
|
||||
)
|
||||
predict_mel = ap.melspectrogram(sample_wav)
|
||||
|
||||
# Sample audio
|
||||
tb_logger.tb_eval_audios(
|
||||
global_step, {
|
||||
"eval/audio": sample_wav}, c.audio["sample_rate"]
|
||||
)
|
||||
tb_logger.tb_eval_audios(global_step, {"eval/audio": sample_wav}, c.audio["sample_rate"])
|
||||
|
||||
# compute spectrograms
|
||||
figures = {
|
||||
"eval/ground_truth": plot_spectrogram(ground_mel.T),
|
||||
"eval/prediction": plot_spectrogram(predict_mel.T)
|
||||
"eval/prediction": plot_spectrogram(predict_mel.T),
|
||||
}
|
||||
tb_logger.tb_eval_figures(global_step, figures)
|
||||
|
||||
|
@@ -347,11 +339,9 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(f" > Loading wavs from: {c.data_path}")
|
||||
if c.feature_path is not None:
|
||||
print(f" > Loading features from: {c.feature_path}")
|
||||
eval_data, train_data = load_wav_feat_data(
|
||||
c.data_path, c.feature_path, c.eval_split_size)
|
||||
eval_data, train_data = load_wav_feat_data(c.data_path, c.feature_path, c.eval_split_size)
|
||||
else:
|
||||
eval_data, train_data = load_wav_data(
|
||||
c.data_path, c.eval_split_size)
|
||||
eval_data, train_data = load_wav_data(c.data_path, c.eval_split_size)
|
||||
# setup model
|
||||
model_wavernn = setup_generator(c)
|
||||
|
||||
|
@@ -404,8 +394,7 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
model_dict = set_init_dict(model_dict, checkpoint["model"], c)
|
||||
model_wavernn.load_state_dict(model_dict)
|
||||
|
||||
print(" > Model restored from step %d" %
|
||||
checkpoint["step"], flush=True)
|
||||
print(" > Model restored from step %d" % checkpoint["step"], flush=True)
|
||||
args.restore_step = checkpoint["step"]
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
@@ -418,24 +407,20 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
print(" > Model has {} parameters".format(num_parameters), flush=True)
|
||||
|
||||
if args.restore_step == 0 or not args.best_path:
|
||||
best_loss = float('inf')
|
||||
best_loss = float("inf")
|
||||
print(" > Starting with inf best loss.")
|
||||
else:
|
||||
print(" > Restoring best loss from "
|
||||
f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path,
|
||||
map_location='cpu')['model_loss']
|
||||
print(" > Restoring best loss from " f"{os.path.basename(args.best_path)} ...")
|
||||
best_loss = torch.load(args.best_path, map_location="cpu")["model_loss"]
|
||||
print(f" > Starting with loaded last best loss {best_loss}.")
|
||||
keep_all_best = c.get('keep_all_best', False)
|
||||
keep_after = c.get('keep_after', 10000) # void if keep_all_best False
|
||||
keep_all_best = c.get("keep_all_best", False)
|
||||
keep_after = c.get("keep_after", 10000) # void if keep_all_best False
|
||||
|
||||
global_step = args.restore_step
|
||||
for epoch in range(0, c.epochs):
|
||||
c_logger.print_epoch_start(epoch, c.epochs)
|
||||
_, global_step = train(model_wavernn, optimizer,
|
||||
criterion, scheduler, scaler, ap, global_step, epoch)
|
||||
eval_avg_loss_dict = evaluate(
|
||||
model_wavernn, criterion, ap, global_step, epoch)
|
||||
_, global_step = train(model_wavernn, optimizer, criterion, scheduler, scaler, ap, global_step, epoch)
|
||||
eval_avg_loss_dict = evaluate(model_wavernn, criterion, ap, global_step, epoch)
|
||||
c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
|
||||
target_loss = eval_avg_loss_dict["avg_model_loss"]
|
||||
best_loss = save_best_model(
|
||||
|
@@ -453,14 +438,13 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
keep_all_best=keep_all_best,
|
||||
keep_after=keep_after,
|
||||
model_losses=eval_avg_loss_dict,
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None
|
||||
scaler=scaler.state_dict() if c.mixed_precision else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments(sys.argv)
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(
|
||||
args, model_class='vocoder')
|
||||
c, OUT_PATH, AUDIO_PATH, c_logger, tb_logger = process_args(args, model_class="vocoder")
|
||||
|
||||
try:
|
||||
main(args)
|
||||
|
|
|
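The mixed-precision branch in the WaveRNN and WaveGrad loops above follows the canonical torch.cuda.amp recipe: scale the loss before backward(), unscale before clipping so the threshold applies to true-magnitude gradients, then let the scaler decide whether to apply the step. A self-contained sketch of that pattern; the model, criterion, and grad_clip names here are placeholders, not the scripts' objects.

    import torch

    scaler = torch.cuda.amp.GradScaler()

    def amp_step(model, criterion, optimizer, x, y, grad_clip):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = criterion(model(x), y)
        scaler.scale(loss).backward()   # backward on the scaled loss
        scaler.unscale_(optimizer)      # gradients back to their true scale
        if grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        scaler.step(optimizer)          # skips the step on inf/nan gradients
        scaler.update()                 # adapt the loss scale for next step
        return loss.item()

Unscaling before the clip is the detail worth remembering: clipping scaled gradients would make the effective threshold depend on the current loss scale.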
@@ -13,14 +13,21 @@ from TTS.vocoder.datasets.wavegrad_dataset import WaveGradDataset
|
|||
from TTS.vocoder.utils.generic_utils import setup_generator
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model_path', type=str, help='Path to model checkpoint.')
|
||||
parser.add_argument('--config_path', type=str, help='Path to model config file.')
|
||||
parser.add_argument('--data_path', type=str, help='Path to data directory.')
|
||||
parser.add_argument('--output_path', type=str, help='path for output file including file name and extension.')
|
||||
parser.add_argument('--num_iter', type=int, help='Number of model inference iterations to optimize the noise schedule for.')
|
||||
parser.add_argument('--use_cuda', type=bool, help='enable/disable CUDA.')
|
||||
parser.add_argument('--num_samples', type=int, default=1, help='Number of data samples used for inference.')
|
||||
parser.add_argument('--search_depth', type=int, default=3, help='Search granularity. Increasing this increases the run-time exponentially.')
|
||||
parser.add_argument("--model_path", type=str, help="Path to model checkpoint.")
|
||||
parser.add_argument("--config_path", type=str, help="Path to model config file.")
|
||||
parser.add_argument("--data_path", type=str, help="Path to data directory.")
|
||||
parser.add_argument("--output_path", type=str, help="path for output file including file name and extension.")
|
||||
parser.add_argument(
|
||||
"--num_iter", type=int, help="Number of model inference iterations that you like to optimize noise schedule for."
|
||||
)
|
||||
parser.add_argument("--use_cuda", type=bool, help="enable/disable CUDA.")
|
||||
parser.add_argument("--num_samples", type=int, default=1, help="Number of datasamples used for inference.")
|
||||
parser.add_argument(
|
||||
"--search_depth",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Search granularity. Increasing this increases the run-time exponentially.",
|
||||
)
|
||||
|
||||
# load config
|
||||
args = parser.parse_args()
|
||||
|
@@ -31,18 +38,20 @@ ap = AudioProcessor(**config.audio)
|
|||
|
||||
# load dataset
|
||||
_, train_data = load_wav_data(args.data_path, 0)
|
||||
train_data = train_data[:args.num_samples]
|
||||
dataset = WaveGradDataset(ap=ap,
|
||||
items=train_data,
|
||||
seq_len=-1,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=True,
|
||||
return_segments=False,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=True)
|
||||
train_data = train_data[: args.num_samples]
|
||||
dataset = WaveGradDataset(
|
||||
ap=ap,
|
||||
items=train_data,
|
||||
seq_len=-1,
|
||||
hop_len=ap.hop_length,
|
||||
pad_short=config.pad_short,
|
||||
conv_pad=config.conv_pad,
|
||||
is_training=True,
|
||||
return_segments=False,
|
||||
use_noise_augment=False,
|
||||
use_cache=False,
|
||||
verbose=True,
|
||||
)
|
||||
loader = DataLoader(
|
||||
dataset,
|
||||
batch_size=1,
|
||||
|
@@ -50,7 +59,8 @@ loader = DataLoader(
|
|||
collate_fn=dataset.collate_full_clips,
|
||||
drop_last=False,
|
||||
num_workers=config.num_loader_workers,
|
||||
pin_memory=False)
|
||||
pin_memory=False,
|
||||
)
|
||||
|
||||
# setup the model
|
||||
model = setup_generator(config)
|
||||
|
@@ -61,9 +71,9 @@ if args.use_cuda:
|
|||
base_values = sorted(10 * np.random.uniform(size=args.search_depth))
|
||||
print(base_values)
|
||||
exponents = 10 ** np.linspace(-6, -1, num=args.num_iter)
|
||||
best_error = float('inf')
|
||||
best_error = float("inf")
|
||||
best_schedule = None
|
||||
total_search_iter = len(base_values)**args.num_iter
|
||||
total_search_iter = len(base_values) ** args.num_iter
|
||||
for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=total_search_iter):
|
||||
beta = exponents * base
|
||||
model.compute_noise_level(beta)
|
||||
|
@@ -84,6 +94,6 @@ for base in tqdm(cartesian_product(base_values, repeat=args.num_iter), total=tot
|
|||
mse = torch.sum((mel - mel_hat) ** 2).mean()
|
||||
if mse.item() < best_error:
|
||||
best_error = mse.item()
|
||||
best_schedule = {'beta': beta}
|
||||
best_schedule = {"beta": beta}
|
||||
print(f" > Found a better schedule. - MSE: {mse.item()}")
|
||||
np.save(args.output_path, best_schedule)
|
||||
|
|
|
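The schedule search above is a brute-force grid: a fixed exponent ladder spanning 1e-6 to 1e-1 is scaled element-wise by every tuple drawn from base_values, so the number of candidates is len(base_values) ** num_iter, which is why the --search_depth help text warns about exponential run time. The core enumeration, isolated with example CLI values:

    import numpy as np
    from itertools import product as cartesian_product

    num_iter, search_depth = 6, 3                       # example --num_iter / --search_depth
    base_values = sorted(10 * np.random.uniform(size=search_depth))
    exponents = 10 ** np.linspace(-6, -1, num=num_iter)
    for base in cartesian_product(base_values, repeat=num_iter):
        beta = exponents * np.asarray(base)             # one candidate noise schedule
        # run model.compute_noise_level(beta) and inference, then keep the
        # schedule with the lowest mel-spectrogram MSE, as in the loop above.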
@@ -13,23 +13,40 @@ from TTS.utils.io import load_config
|
|||
|
||||
def create_argparser():
|
||||
def convert_boolean(x):
|
||||
return x.lower() in ['true', '1', 'yes']
|
||||
return x.lower() in ["true", "1", "yes"]
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--list_models', type=convert_boolean, nargs='?', const=True, default=False, help='list available pre-trained tts and vocoder models.')
|
||||
parser.add_argument('--model_name', type=str, default="tts_models/en/ljspeech/speedy-speech-wn", help='name of one of the released tts models.')
|
||||
parser.add_argument('--vocoder_name', type=str, default=None, help='name of one of the released vocoder models.')
|
||||
parser.add_argument('--tts_checkpoint', type=str, help='path to custom tts checkpoint file')
|
||||
parser.add_argument('--tts_config', type=str, help='path to custom tts config.json file')
|
||||
parser.add_argument('--tts_speakers', type=str, help='path to JSON file containing speaker ids, if speaker ids are used in the model')
|
||||
parser.add_argument('--vocoder_config', type=str, default=None, help='path to vocoder config file.')
|
||||
parser.add_argument('--vocoder_checkpoint', type=str, default=None, help='path to vocoder checkpoint file.')
|
||||
parser.add_argument('--port', type=int, default=5002, help='port to listen on.')
|
||||
parser.add_argument('--use_cuda', type=convert_boolean, default=False, help='true to use CUDA.')
|
||||
parser.add_argument('--debug', type=convert_boolean, default=False, help='true to enable Flask debug mode.')
|
||||
parser.add_argument('--show_details', type=convert_boolean, default=False, help='Generate model detail page.')
|
||||
parser.add_argument(
|
||||
"--list_models",
|
||||
type=convert_boolean,
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="list available pre-trained tts and vocoder models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name",
|
||||
type=str,
|
||||
default="tts_models/en/ljspeech/speedy-speech-wn",
|
||||
help="name of one of the released tts models.",
|
||||
)
|
||||
parser.add_argument("--vocoder_name", type=str, default=None, help="name of one of the released vocoder models.")
|
||||
parser.add_argument("--tts_checkpoint", type=str, help="path to custom tts checkpoint file")
|
||||
parser.add_argument("--tts_config", type=str, help="path to custom tts config.json file")
|
||||
parser.add_argument(
|
||||
"--tts_speakers",
|
||||
type=str,
|
||||
help="path to JSON file containing speaker ids, if speaker ids are used in the model",
|
||||
)
|
||||
parser.add_argument("--vocoder_config", type=str, default=None, help="path to vocoder config file.")
|
||||
parser.add_argument("--vocoder_checkpoint", type=str, default=None, help="path to vocoder checkpoint file.")
|
||||
parser.add_argument("--port", type=int, default=5002, help="port to listen on.")
|
||||
parser.add_argument("--use_cuda", type=convert_boolean, default=False, help="true to use CUDA.")
|
||||
parser.add_argument("--debug", type=convert_boolean, default=False, help="true to enable Flask debug mode.")
|
||||
parser.add_argument("--show_details", type=convert_boolean, default=False, help="Generate model detail page.")
|
||||
return parser
|
||||
|
||||
|
||||
# parse the args
|
||||
args = create_argparser().parse_args()
|
||||
|
||||
|
@@ -43,7 +60,7 @@ if args.list_models:
|
|||
# update in-use models to the specified released models.
|
||||
if args.model_name is not None:
|
||||
tts_checkpoint_file, tts_config_file, tts_json_dict = manager.download_model(args.model_name)
|
||||
args.vocoder_name = tts_json_dict['default_vocoder'] if args.vocoder_name is None else args.vocoder_name
|
||||
args.vocoder_name = tts_json_dict["default_vocoder"] if args.vocoder_name is None else args.vocoder_name
|
||||
|
||||
if args.vocoder_name is not None:
|
||||
vocoder_checkpoint_file, vocoder_config_file, vocoder_json_dict = manager.download_model(args.vocoder_name)
|
||||
|
@@ -59,16 +76,19 @@ if not args.vocoder_checkpoint and os.path.isfile(vocoder_checkpoint_file):
|
|||
if not args.vocoder_config and os.path.isfile(vocoder_config_file):
|
||||
args.vocoder_config = vocoder_config_file
|
||||
|
||||
synthesizer = Synthesizer(args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda)
|
||||
synthesizer = Synthesizer(
|
||||
args.tts_checkpoint, args.tts_config, args.vocoder_checkpoint, args.vocoder_config, args.use_cuda
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
@app.route("/")
|
||||
def index():
|
||||
return render_template('index.html', show_details=args.show_details)
|
||||
return render_template("index.html", show_details=args.show_details)
|
||||
|
||||
@app.route('/details')
|
||||
|
||||
@app.route("/details")
|
||||
def details():
|
||||
model_config = load_config(args.tts_config)
|
||||
if args.vocoder_config is not None and os.path.isfile(args.vocoder_config):
|
||||
|
@@ -76,26 +96,28 @@ def details():
|
|||
else:
|
||||
vocoder_config = None
|
||||
|
||||
return render_template('details.html',
|
||||
show_details=args.show_details
|
||||
, model_config=model_config
|
||||
, vocoder_config=vocoder_config
|
||||
, args=args.__dict__
|
||||
)
|
||||
return render_template(
|
||||
"details.html",
|
||||
show_details=args.show_details,
|
||||
model_config=model_config,
|
||||
vocoder_config=vocoder_config,
|
||||
args=args.__dict__,
|
||||
)
|
||||
|
||||
@app.route('/api/tts', methods=['GET'])
|
||||
|
||||
@app.route("/api/tts", methods=["GET"])
|
||||
def tts():
|
||||
text = request.args.get('text')
|
||||
text = request.args.get("text")
|
||||
print(" > Model input: {}".format(text))
|
||||
wavs = synthesizer.tts(text)
|
||||
out = io.BytesIO()
|
||||
synthesizer.save_wav(wavs, out)
|
||||
return send_file(out, mimetype='audio/wav')
|
||||
return send_file(out, mimetype="audio/wav")
|
||||
|
||||
|
||||
def main():
|
||||
app.run(debug=args.debug, host='0.0.0.0', port=args.port)
|
||||
app.run(debug=args.debug, host="0.0.0.0", port=args.port)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
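Once the server is running, the /api/tts route above accepts a GET request with a text query parameter and streams back a WAV file. A minimal client, assuming a local instance on the default port 5002:

    import requests

    resp = requests.get("http://localhost:5002/api/tts", params={"text": "Hello world."})
    resp.raise_for_status()
    with open("out.wav", "wb") as f:
        f.write(resp.content)  # the endpoint returns audio/wav bytes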
@@ -7,9 +7,19 @@ from torch.utils.data import Dataset
|
|||
|
||||
|
||||
class MyDataset(Dataset):
|
||||
def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
|
||||
storage_size=1, sample_from_storage_p=0.5, additive_noise=0,
|
||||
num_utter_per_speaker=10, skip_speakers=False, verbose=False):
|
||||
def __init__(
|
||||
self,
|
||||
ap,
|
||||
meta_data,
|
||||
voice_len=1.6,
|
||||
num_speakers_in_batch=64,
|
||||
storage_size=1,
|
||||
sample_from_storage_p=0.5,
|
||||
additive_noise=0,
|
||||
num_utter_per_speaker=10,
|
||||
skip_speakers=False,
|
||||
verbose=False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ap (TTS.tts.utils.AudioProcessor): audio processor object.
|
||||
|
@@ -28,7 +38,7 @@ class MyDataset(Dataset):
|
|||
self.ap = ap
|
||||
self.verbose = verbose
|
||||
self.__parse_items()
|
||||
self.storage = queue.Queue(maxsize=storage_size*num_speakers_in_batch)
|
||||
self.storage = queue.Queue(maxsize=storage_size * num_speakers_in_batch)
|
||||
self.sample_from_storage_p = float(sample_from_storage_p)
|
||||
self.additive_noise = float(additive_noise)
|
||||
if self.verbose:
|
||||
|
@@ -69,11 +79,14 @@ class MyDataset(Dataset):
|
|||
if speaker_ in self.speaker_to_utters.keys():
|
||||
self.speaker_to_utters[speaker_].append(path_)
|
||||
else:
|
||||
self.speaker_to_utters[speaker_] = [path_, ]
|
||||
self.speaker_to_utters[speaker_] = [
|
||||
path_,
|
||||
]
|
||||
|
||||
if self.skip_speakers:
|
||||
self.speaker_to_utters = {k: v for (k, v) in self.speaker_to_utters.items() if
|
||||
len(v) >= self.num_utter_per_speaker}
|
||||
self.speaker_to_utters = {
|
||||
k: v for (k, v) in self.speaker_to_utters.items() if len(v) >= self.num_utter_per_speaker
|
||||
}
|
||||
|
||||
self.speakers = [k for (k, v) in self.speaker_to_utters.items()]
|
||||
|
||||
|
@@ -100,13 +113,9 @@ class MyDataset(Dataset):
|
|||
def __sample_speaker(self):
|
||||
speaker = random.sample(self.speakers, 1)[0]
|
||||
if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
|
||||
utters = random.choices(
|
||||
self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
|
||||
)
|
||||
utters = random.choices(self.speaker_to_utters[speaker], k=self.num_utter_per_speaker)
|
||||
else:
|
||||
utters = random.sample(
|
||||
self.speaker_to_utters[speaker], self.num_utter_per_speaker
|
||||
)
|
||||
utters = random.sample(self.speaker_to_utters[speaker], self.num_utter_per_speaker)
|
||||
return speaker, utters
|
||||
|
||||
def __sample_speaker_utterances(self, speaker):
|
||||
|
@@ -160,7 +169,9 @@ class MyDataset(Dataset):
|
|||
|
||||
# get a random subset of each of the wavs and convert to MFCC.
|
||||
offsets_ = [random.randint(0, wav.shape[0] - self.seq_len) for wav in wavs_]
|
||||
mels_ = [self.ap.melspectrogram(wavs_[i][offsets_[i]: offsets_[i] + self.seq_len]) for i in range(len(wavs_))]
|
||||
mels_ = [
|
||||
self.ap.melspectrogram(wavs_[i][offsets_[i] : offsets_[i] + self.seq_len]) for i in range(len(wavs_))
|
||||
]
|
||||
feats_ = [torch.FloatTensor(mel) for mel in mels_]
|
||||
|
||||
labels.append(labels_)
|
||||
|
|
|
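The last hunk above is the heart of the speaker-encoder batching: every utterance gets a random fixed-length crop before mel extraction, so all features in a batch share the same shape. The same logic pulled out of the class; ap is assumed to be a TTS.utils.audio.AudioProcessor and seq_len the crop length in samples (voice_len times the sample rate).

    import random
    import torch

    def crop_and_featurize(ap, wavs, seq_len):
        # one random offset per utterance, then a seq_len-sample crop -> mel
        offsets = [random.randint(0, wav.shape[0] - seq_len) for wav in wavs]
        mels = [ap.melspectrogram(w[o : o + seq_len]) for w, o in zip(wavs, offsets)]
        return [torch.FloatTensor(m) for m in mels]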
@@ -16,14 +16,14 @@ class GE2ELoss(nn.Module):
|
|||
- init_w (float): defines the initial value of w in Equation (5) of [1]
|
||||
- init_b (float): defines the initial value of b in Equation (5) of [1]
|
||||
"""
|
||||
super(GE2ELoss, self).__init__()
|
||||
super().__init__()
|
||||
# pylint: disable=E1102
|
||||
self.w = nn.Parameter(torch.tensor(init_w))
|
||||
# pylint: disable=E1102
|
||||
self.b = nn.Parameter(torch.tensor(init_b))
|
||||
self.loss_method = loss_method
|
||||
|
||||
print(' > Initialised Generalized End-to-End loss')
|
||||
print(" > Initialised Generalized End-to-End loss")
|
||||
|
||||
assert self.loss_method in ["softmax", "contrast"]
|
||||
|
||||
|
@@ -55,9 +55,7 @@ class GE2ELoss(nn.Module):
|
|||
for spkr_idx, speaker in enumerate(dvecs):
|
||||
cs_row = []
|
||||
for utt_idx, utterance in enumerate(speaker):
|
||||
new_centroids = self.calc_new_centroids(
|
||||
dvecs, centroids, spkr_idx, utt_idx
|
||||
)
|
||||
new_centroids = self.calc_new_centroids(dvecs, centroids, spkr_idx, utt_idx)
|
||||
# vector based cosine similarity for speed
|
||||
cs_row.append(
|
||||
torch.clamp(
|
||||
|
@@ -99,14 +97,8 @@ class GE2ELoss(nn.Module):
|
|||
L_row = []
|
||||
for i in range(M):
|
||||
centroids_sigmoids = torch.sigmoid(cos_sim_matrix[j, i])
|
||||
excl_centroids_sigmoids = torch.cat(
|
||||
(centroids_sigmoids[:j], centroids_sigmoids[j + 1 :])
|
||||
)
|
||||
L_row.append(
|
||||
1.0
|
||||
- torch.sigmoid(cos_sim_matrix[j, i, j])
|
||||
+ torch.max(excl_centroids_sigmoids)
|
||||
)
|
||||
excl_centroids_sigmoids = torch.cat((centroids_sigmoids[:j], centroids_sigmoids[j + 1 :]))
|
||||
L_row.append(1.0 - torch.sigmoid(cos_sim_matrix[j, i, j]) + torch.max(excl_centroids_sigmoids))
|
||||
L_row = torch.stack(L_row)
|
||||
L.append(L_row)
|
||||
return torch.stack(L)
|
||||
|
@@ -122,6 +114,7 @@ class GE2ELoss(nn.Module):
|
|||
L = self.embed_loss(dvecs, cos_sim_matrix)
|
||||
return L.mean()
|
||||
|
||||
|
||||
# adapted from https://github.com/clovaai/voxceleb_trainer/blob/master/loss/angleproto.py
|
||||
class AngleProtoLoss(nn.Module):
|
||||
"""
|
||||
|
@@ -134,15 +127,16 @@ class AngleProtoLoss(nn.Module):
|
|||
- init_w (float): defines the initial value of w
|
||||
- init_b (float): defines the initial value of b
|
||||
"""
|
||||
|
||||
def __init__(self, init_w=10.0, init_b=-5.0):
|
||||
super(AngleProtoLoss, self).__init__()
|
||||
super().__init__()
|
||||
# pylint: disable=E1102
|
||||
self.w = nn.Parameter(torch.tensor(init_w))
|
||||
# pylint: disable=E1102
|
||||
self.b = nn.Parameter(torch.tensor(init_b))
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
print(' > Initialised Angular Prototypical loss')
|
||||
print(" > Initialised Angular Prototypical loss")
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
|
@ -152,7 +146,10 @@ class AngleProtoLoss(nn.Module):
|
|||
out_positive = x[:, 0, :]
|
||||
num_speakers = out_anchor.size()[0]
|
||||
|
||||
cos_sim_matrix = F.cosine_similarity(out_positive.unsqueeze(-1).expand(-1, -1, num_speakers), out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2))
|
||||
cos_sim_matrix = F.cosine_similarity(
|
||||
out_positive.unsqueeze(-1).expand(-1, -1, num_speakers),
|
||||
out_anchor.unsqueeze(-1).expand(-1, -1, num_speakers).transpose(0, 2),
|
||||
)
|
||||
torch.clamp(self.w, 1e-6)
|
||||
cos_sim_matrix = cos_sim_matrix * self.w + self.b
|
||||
label = torch.arange(num_speakers).to(cos_sim_matrix.device)
|
||||
|
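A shape check for the similarity matrix built above (n speakers, one anchor and one positive embedding each). Note, incidentally, that torch.clamp is not in-place, so the bare torch.clamp(self.w, 1e-6) in the hunk returns a new tensor that is immediately discarded:

import torch
import torch.nn.functional as F

n, d = 8, 256
out_anchor = torch.randn(n, d)    # one anchor embedding per speaker
out_positive = torch.randn(n, d)  # one positive embedding per speaker

# Entry [i, j] is the cosine similarity of positive i against anchor j.
cos_sim_matrix = F.cosine_similarity(
    out_positive.unsqueeze(-1).expand(-1, -1, n),
    out_anchor.unsqueeze(-1).expand(-1, -1, n).transpose(0, 2),
)
assert cos_sim_matrix.shape == (n, n)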

@ -16,19 +16,19 @@ class LSTMWithProjection(nn.Module):
        o, (_, _) = self.lstm(x)
        return self.linear(o)


class LSTMWithoutProjection(nn.Module):
    def __init__(self, input_dim, lstm_dim, proj_dim, num_lstm_layers):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_dim,
                            hidden_size=lstm_dim,
                            num_layers=num_lstm_layers,
                            batch_first=True)
        self.lstm = nn.LSTM(input_size=input_dim, hidden_size=lstm_dim, num_layers=num_lstm_layers, batch_first=True)
        self.linear = nn.Linear(lstm_dim, proj_dim, bias=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        _, (hidden, _) = self.lstm(x)
        return self.relu(self.linear(hidden[-1]))


class SpeakerEncoder(nn.Module):
    def __init__(self, input_dim, proj_dim=256, lstm_dim=768, num_lstm_layers=3, use_lstm_with_projection=True):
        super().__init__()

@ -106,7 +106,5 @@ class SpeakerEncoder(nn.Module):
            if embed is None:
                embed = self.inference(frames)
            else:
                embed[cur_iter <= num_iters, :] += self.inference(
                    frames[cur_iter <= num_iters, :, :]
                )
                embed[cur_iter <= num_iters, :] += self.inference(frames[cur_iter <= num_iters, :, :])
        return embed / num_iters
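A quick shape sketch for LSTMWithoutProjection as defined above; the (batch, time, feature) input layout follows from batch_first=True, and the dimensions are illustrative:

import torch

layer = LSTMWithoutProjection(input_dim=80, lstm_dim=768, proj_dim=256, num_lstm_layers=3)
x = torch.randn(2, 120, 80)   # (batch, time, feature) because batch_first=True
out = layer(x)                # last layer's final hidden state, projected + ReLU'd
assert out.shape == (2, 256)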
@ -9,108 +9,114 @@ from TTS.utils.generic_utils import check_argument

def to_camel(text):
    text = text.capitalize()
    return re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
    return re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)


def setup_model(c):
    model = SpeakerEncoder(c.model['input_dim'], c.model['proj_dim'],
                           c.model['lstm_dim'], c.model['num_lstm_layers'])
    model = SpeakerEncoder(c.model["input_dim"], c.model["proj_dim"], c.model["lstm_dim"], c.model["num_lstm_layers"])
    return model


def save_checkpoint(model, optimizer, model_loss, out_path,
                    current_step, epoch):
    checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
def save_checkpoint(model, optimizer, model_loss, out_path, current_step, epoch):
    checkpoint_path = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(out_path, checkpoint_path)
    print(" | | > Checkpoint saving : {}".format(checkpoint_path))

    new_state_dict = model.state_dict()
    state = {
        'model': new_state_dict,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'epoch': epoch,
        'loss': model_loss,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        "model": new_state_dict,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "epoch": epoch,
        "loss": model_loss,
        "date": datetime.date.today().strftime("%B %d, %Y"),
    }
    torch.save(state, checkpoint_path)


def save_best_model(model, optimizer, model_loss, best_loss, out_path,
                    current_step):
def save_best_model(model, optimizer, model_loss, best_loss, out_path, current_step):
    if model_loss < best_loss:
        new_state_dict = model.state_dict()
        state = {
            'model': new_state_dict,
            'optimizer': optimizer.state_dict(),
            'step': current_step,
            'loss': model_loss,
            'date': datetime.date.today().strftime("%B %d, %Y"),
            "model": new_state_dict,
            "optimizer": optimizer.state_dict(),
            "step": current_step,
            "loss": model_loss,
            "date": datetime.date.today().strftime("%B %d, %Y"),
        }
        best_loss = model_loss
        bestmodel_path = 'best_model.pth.tar'
        bestmodel_path = "best_model.pth.tar"
        bestmodel_path = os.path.join(out_path, bestmodel_path)
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(
            model_loss, bestmodel_path))
        print("\n > BEST MODEL ({0:.5f}) : {1:}".format(model_loss, bestmodel_path))
        torch.save(state, bestmodel_path)
    return best_loss


def check_config_speaker_encoder(c):
    """Check the config.json file of the speaker encoder"""
    check_argument('run_name', c, restricted=True, val_type=str)
    check_argument('run_description', c, val_type=str)
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # audio processing parameters
    check_argument('audio', c, restricted=True, val_type=dict)
    check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
    check_argument("audio", c, restricted=True, val_type=dict)
    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        c["audio"],
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

    # training parameters
    check_argument('loss', c, enum_list=['ge2e', 'angleproto'], restricted=True, val_type=str)
    check_argument('grad_clip', c, restricted=True, val_type=float)
    check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    check_argument('lr_decay', c, restricted=True, val_type=bool)
    check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
    check_argument('num_speakers_in_batch', c, restricted=True, val_type=int)
    check_argument('num_loader_workers', c, restricted=True, val_type=int)
    check_argument('wd', c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument("loss", c, enum_list=["ge2e", "angleproto"], restricted=True, val_type=str)
    check_argument("grad_clip", c, restricted=True, val_type=float)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("lr_decay", c, restricted=True, val_type=bool)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)
    check_argument("num_speakers_in_batch", c, restricted=True, val_type=int)
    check_argument("num_loader_workers", c, restricted=True, val_type=int)
    check_argument("wd", c, restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # checkpoint and output parameters
    check_argument('steps_plot_stats', c, restricted=True, val_type=int)
    check_argument('checkpoint', c, restricted=True, val_type=bool)
    check_argument('save_step', c, restricted=True, val_type=int)
    check_argument('print_step', c, restricted=True, val_type=int)
    check_argument('output_path', c, restricted=True, val_type=str)
    check_argument("steps_plot_stats", c, restricted=True, val_type=int)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("save_step", c, restricted=True, val_type=int)
    check_argument("print_step", c, restricted=True, val_type=int)
    check_argument("output_path", c, restricted=True, val_type=str)

    # model parameters
    check_argument('model', c, restricted=True, val_type=dict)
    check_argument('input_dim', c['model'], restricted=True, val_type=int)
    check_argument('proj_dim', c['model'], restricted=True, val_type=int)
    check_argument('lstm_dim', c['model'], restricted=True, val_type=int)
    check_argument('num_lstm_layers', c['model'], restricted=True, val_type=int)
    check_argument('use_lstm_with_projection', c['model'], restricted=True, val_type=bool)
    check_argument("model", c, restricted=True, val_type=dict)
    check_argument("input_dim", c["model"], restricted=True, val_type=int)
    check_argument("proj_dim", c["model"], restricted=True, val_type=int)
    check_argument("lstm_dim", c["model"], restricted=True, val_type=int)
    check_argument("num_lstm_layers", c["model"], restricted=True, val_type=int)
    check_argument("use_lstm_with_projection", c["model"], restricted=True, val_type=bool)

    # in-memory storage parameters
    check_argument('storage', c, restricted=True, val_type=dict)
    check_argument('sample_from_storage_p', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument('storage_size', c['storage'], restricted=True, val_type=int, min_val=1, max_val=100)
    check_argument('additive_noise', c['storage'], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument("storage", c, restricted=True, val_type=dict)
    check_argument("sample_from_storage_p", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)
    check_argument("storage_size", c["storage"], restricted=True, val_type=int, min_val=1, max_val=100)
    check_argument("additive_noise", c["storage"], restricted=True, val_type=float, min_val=0.0, max_val=1.0)

    # datasets - checking only the first entry
    check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        check_argument('name', dataset_entry, restricted=True, val_type=str)
        check_argument('path', dataset_entry, restricted=True, val_type=str)
        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
        check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)
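An illustrative config fragment that satisfies the audio checks above; the values are plausible placeholders, not the repository defaults:

config_fragment = {
    "run_name": "speaker-encoder",
    "run_description": "GE2E speaker encoder",
    "audio": {
        "num_mels": 80,            # 10 <= num_mels <= 2056
        "fft_size": 1024,          # 128 <= fft_size <= 4058
        "sample_rate": 22050,      # 512 <= sample_rate <= 100000
        "frame_length_ms": 50.0,   # or provide the 'win_length' alternative instead
        "frame_shift_ms": 12.5,    # or provide the 'hop_length' alternative instead
        "preemphasis": 0.98,
        "min_level_db": -100,
        "ref_level_db": 20,
        "power": 1.5,
        "griffin_lim_iters": 60,
    },
}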

@ -17,7 +17,7 @@
# Only support eager mode and TF>=2.0.0
# pylint: disable=no-member, invalid-name, relative-beyond-top-level
# pylint: disable=too-many-locals, too-many-statements, too-many-arguments, too-many-instance-attributes
''' voxceleb 1 & 2 '''
""" voxceleb 1 & 2 """

import os
import sys

@ -32,40 +32,38 @@ import soundfile as sf
gfile = tf.compat.v1.gfile

SUBSETS = {
    "vox1_dev_wav":
        ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad"],
    "vox1_test_wav":
        ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
    "vox2_dev_aac":
        ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
         "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah"],
    "vox2_test_aac":
        ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"]
    "vox1_dev_wav": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partaa",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partab",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partac",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_dev_wav_partad",
    ],
    "vox1_test_wav": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox1_test_wav.zip"],
    "vox2_dev_aac": [
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaa",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partab",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partac",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partad",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partae",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partaf",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partag",
        "http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_dev_aac_partah",
    ],
    "vox2_test_aac": ["http://www.robots.ox.ac.uk/~vgg/data/voxceleb/vox1a/vox2_test_aac.zip"],
}

MD5SUM = {
    "vox1_dev_wav": "ae63e55b951748cc486645f532ba230b",
    "vox2_dev_aac": "bbc063c46078a602ca71605645c2a402",
    "vox1_test_wav": "185fdc63c3c739954633d50379a3d102",
    "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312"
    "vox2_test_aac": "0d2b3ea430a821c33263b5ea37ede312",
}

USER = {
    "user": "",
    "password": ""
}
USER = {"user": "", "password": ""}

speaker_id_dict = {}


def download_and_extract(directory, subset, urls):
    """Download and extract the given split of dataset.

@ -83,31 +81,30 @@ def download_and_extract(directory, subset, urls):
            if os.path.exists(zip_filepath):
                continue
            logging.info("Downloading %s to %s" % (url, zip_filepath))
            subprocess.call('wget %s --user %s --password %s -O %s' %
                            (url, USER["user"], USER["password"], zip_filepath), shell=True)
            subprocess.call(
                "wget %s --user %s --password %s -O %s" % (url, USER["user"], USER["password"], zip_filepath),
                shell=True,
            )

            statinfo = os.stat(zip_filepath)
            logging.info(
                "Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size)
            )
            logging.info("Successfully downloaded %s, size(bytes): %d" % (url, statinfo.st_size))

        # concatenate all parts into zip files
        if ".zip" not in zip_filepath:
            zip_filepath = "_".join(zip_filepath.split("_")[:-1])
            subprocess.call('cat %s* > %s.zip' %
                            (zip_filepath, zip_filepath), shell=True)
            subprocess.call("cat %s* > %s.zip" % (zip_filepath, zip_filepath), shell=True)
            zip_filepath += ".zip"
        extract_path = zip_filepath.strip(".zip")

        # check zip file md5sum
        md5 = hashlib.md5(open(zip_filepath, 'rb').read()).hexdigest()
        md5 = hashlib.md5(open(zip_filepath, "rb").read()).hexdigest()
        if md5 != MD5SUM[subset]:
            raise ValueError("md5sum of %s mismatch" % zip_filepath)

        with zipfile.ZipFile(zip_filepath, "r") as zfile:
            zfile.extractall(directory)
            extract_path_ori = os.path.join(directory, zfile.infolist()[0].filename)
            subprocess.call('mv %s %s' % (extract_path_ori, extract_path), shell=True)
            subprocess.call("mv %s %s" % (extract_path_ori, extract_path), shell=True)
    finally:
        # gfile.Remove(zip_filepath)
        pass
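The md5 gate above reads each archive fully into memory before hashing; a chunked variant of the same check, sketched as a standalone helper, avoids that for the multi-gigabyte VoxCeleb parts:

import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    # Stream the file in 1 MiB chunks so large zips never sit in RAM at once.
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()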
@ -148,8 +145,7 @@ def decode_aac_with_ffmpeg(aac_file, wav_file):
    return True


def convert_audio_and_make_label(input_dir, subset,
                                 output_dir, output_file):
def convert_audio_and_make_label(input_dir, subset, output_dir, output_file):
    """Optionally convert AAC to WAV and make speaker labels.
    Args:
        input_dir: the directory which holds the input dataset.

@ -167,7 +163,7 @@ def convert_audio_and_make_label(input_dir, subset,
        for filename in filenames:
            name, ext = os.path.splitext(filename)
            if ext.lower() == ".wav":
                _, ext2 = (os.path.splitext(name))
                _, ext2 = os.path.splitext(name)
                if ext2:
                    continue
                wav_file = os.path.join(root, filename)

@ -186,15 +182,12 @@ def convert_audio_and_make_label(input_dir, subset,
                speaker_id_dict[speaker_name] = num
            # wav_filesize = os.path.getsize(wav_file)
            wav_length = len(sf.read(wav_file)[0])
            files.append(
                (os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name)
            )
            files.append((os.path.abspath(wav_file), wav_length, speaker_id_dict[speaker_name], speaker_name))

    # Write to CSV file which contains four columns:
    # "wav_filename", "wav_length_ms", "speaker_id", "speaker_name".
    csv_file_path = os.path.join(output_dir, output_file)
    df = pandas.DataFrame(
        data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
    df = pandas.DataFrame(data=files, columns=["wav_filename", "wav_length_ms", "speaker_id", "speaker_name"])
    df.to_csv(csv_file_path, index=False, sep="\t")
    logging.info("Successfully generated csv file {}".format(csv_file_path))

@ -205,19 +198,14 @@ def processor(directory, subset, force_process):
    if subset not in urls:
        raise ValueError(subset, "is not in voxceleb")

    subset_csv = os.path.join(directory, subset + '.csv')
    subset_csv = os.path.join(directory, subset + ".csv")
    if not force_process and os.path.exists(subset_csv):
        return subset_csv

    logging.info("Downloading and process the voxceleb in %s", directory)
    logging.info("Preparing subset %s", subset)
    download_and_extract(directory, subset, urls[subset])
    convert_audio_and_make_label(
        directory,
        subset,
        directory,
        subset + ".csv"
    )
    convert_audio_and_make_label(directory, subset, directory, subset + ".csv")
    logging.info("Finished downloading and processing")
    return subset_csv

@ -7,31 +7,31 @@ import numpy as np
import torch
import tqdm
from torch.utils.data import Dataset
from TTS.tts.utils.data import (prepare_data, prepare_stop_target,
                                prepare_tensor)
from TTS.tts.utils.text import (pad_with_eos_bos, phoneme_to_sequence,
                                text_to_sequence)
from TTS.tts.utils.data import prepare_data, prepare_stop_target, prepare_tensor
from TTS.tts.utils.text import pad_with_eos_bos, phoneme_to_sequence, text_to_sequence


class MyDataset(Dataset):
    def __init__(self,
                 outputs_per_step,
                 text_cleaner,
                 compute_linear_spec,
                 ap,
                 meta_data,
                 tp=None,
                 add_blank=False,
                 batch_group_size=0,
                 min_seq_len=0,
                 max_seq_len=float("inf"),
                 use_phonemes=True,
                 phoneme_cache_path=None,
                 phoneme_language="en-us",
                 enable_eos_bos=False,
                 speaker_mapping=None,
                 use_noise_augment=False,
                 verbose=False):
    def __init__(
        self,
        outputs_per_step,
        text_cleaner,
        compute_linear_spec,
        ap,
        meta_data,
        tp=None,
        add_blank=False,
        batch_group_size=0,
        min_seq_len=0,
        max_seq_len=float("inf"),
        use_phonemes=True,
        phoneme_cache_path=None,
        phoneme_language="en-us",
        enable_eos_bos=False,
        speaker_mapping=None,
        use_noise_augment=False,
        verbose=False,
    ):
        """
        Args:
            outputs_per_step (int): number of time frames predicted per step.

@ -53,7 +53,7 @@ class MyDataset(Dataset):
            use_noise_augment (bool): enable adding random noise to wav for augmentation.
            verbose (bool): print diagnostic information.
        """
        super(MyDataset, self).__init__()
        super().__init__()
        self.batch_group_size = batch_group_size
        self.items = meta_data
        self.outputs_per_step = outputs_per_step

@ -88,45 +88,42 @@ class MyDataset(Dataset):

    @staticmethod
    def load_np(filename):
        data = np.load(filename).astype('float32')
        data = np.load(filename).astype("float32")
        return data

    @staticmethod
    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners,
                                             language, tp, add_blank):
    def _generate_and_cache_phoneme_sequence(text, cache_path, cleaners, language, tp, add_blank):
        """generate a phoneme sequence from text.
        since the usage is for subsequent caching, we never add bos and
        eos chars here. Instead we add those dynamically later; based on the
        config option."""
        phonemes = phoneme_to_sequence(text, [cleaners],
                                       language=language,
                                       enable_eos_bos=False,
                                       tp=tp,
                                       add_blank=add_blank)
        phonemes = phoneme_to_sequence(
            text, [cleaners], language=language, enable_eos_bos=False, tp=tp, add_blank=add_blank
        )
        phonemes = np.asarray(phonemes, dtype=np.int32)
        np.save(cache_path, phonemes)
        return phonemes

    @staticmethod
    def _load_or_generate_phoneme_sequence(wav_file, text, phoneme_cache_path,
                                           enable_eos_bos, cleaners, language,
                                           tp, add_blank):
    def _load_or_generate_phoneme_sequence(
        wav_file, text, phoneme_cache_path, enable_eos_bos, cleaners, language, tp, add_blank
    ):
        file_name = os.path.splitext(os.path.basename(wav_file))[0]

        # different names for normal phonemes and with blank chars.
        file_name_ext = '_blanked_phoneme.npy' if add_blank else '_phoneme.npy'
        cache_path = os.path.join(phoneme_cache_path,
                                  file_name + file_name_ext)
        file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
        cache_path = os.path.join(phoneme_cache_path, file_name + file_name_ext)
        try:
            phonemes = np.load(cache_path)
        except FileNotFoundError:
            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, tp, add_blank)
                text, cache_path, cleaners, language, tp, add_blank
            )
        except (ValueError, IOError):
            print(" [!] failed loading phonemes for {}. "
                  "Recomputing.".format(wav_file))
            print(" [!] failed loading phonemes for {}. " "Recomputing.".format(wav_file))
            phonemes = MyDataset._generate_and_cache_phoneme_sequence(
                text, cache_path, cleaners, language, tp, add_blank)
                text, cache_path, cleaners, language, tp, add_blank
            )
        if enable_eos_bos:
            phonemes = pad_with_eos_bos(phonemes, tp=tp)
        phonemes = np.asarray(phonemes, dtype=np.int32)
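The cache naming above maps each wav to one .npy file, with a distinct suffix when blank characters are interleaved; as a standalone helper with an illustrative path:

import os


def phoneme_cache_path(wav_file, cache_dir, add_blank):
    # Mirrors the lookup above: <cache_dir>/<wav stem>_phoneme.npy,
    # or <wav stem>_blanked_phoneme.npy when add_blank is set.
    file_name = os.path.splitext(os.path.basename(wav_file))[0]
    file_name_ext = "_blanked_phoneme.npy" if add_blank else "_phoneme.npy"
    return os.path.join(cache_dir, file_name + file_name_ext)


assert phoneme_cache_path("/data/wavs/utt1.wav", "/cache", False) == "/cache/utt1_phoneme.npy"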
@ -150,15 +147,20 @@ class MyDataset(Dataset):
        if not self.input_seq_computed:
            if self.use_phonemes:
                text = self._load_or_generate_phoneme_sequence(
                    wav_file, text, self.phoneme_cache_path,
                    self.enable_eos_bos, self.cleaners, self.phoneme_language,
                    self.tp, self.add_blank)
                    wav_file,
                    text,
                    self.phoneme_cache_path,
                    self.enable_eos_bos,
                    self.cleaners,
                    self.phoneme_language,
                    self.tp,
                    self.add_blank,
                )

            else:
                text = np.asarray(text_to_sequence(text, [self.cleaners],
                                                   tp=self.tp,
                                                   add_blank=self.add_blank),
                                  dtype=np.int32)
                text = np.asarray(
                    text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
                )

        assert text.size > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

@ -173,12 +175,12 @@ class MyDataset(Dataset):
            return self.load_data(100)

        sample = {
            'text': text,
            'wav': wav,
            'attn': attn,
            'item_idx': self.items[idx][1],
            'speaker_name': speaker_name,
            'wav_file_name': os.path.basename(wav_file)
            "text": text,
            "wav": wav,
            "attn": attn,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
            "wav_file_name": os.path.basename(wav_file),
        }
        return sample

@ -187,8 +189,7 @@ class MyDataset(Dataset):
        item = args[0]
        func_args = args[1]
        text, wav_file, *_ = item
        phonemes = MyDataset._load_or_generate_phoneme_sequence(
            wav_file, text, *func_args)
        phonemes = MyDataset._load_or_generate_phoneme_sequence(wav_file, text, *func_args)
        return phonemes

    def compute_input_seq(self, num_workers=0):

@ -199,17 +200,19 @@ class MyDataset(Dataset):
                print(" | > Computing input sequences ...")
            for idx, item in enumerate(tqdm.tqdm(self.items)):
                text, *_ = item
                sequence = np.asarray(text_to_sequence(
                    text, [self.cleaners],
                    tp=self.tp,
                    add_blank=self.add_blank),
                                      dtype=np.int32)
                sequence = np.asarray(
                    text_to_sequence(text, [self.cleaners], tp=self.tp, add_blank=self.add_blank), dtype=np.int32
                )
                self.items[idx][0] = sequence

        else:
            func_args = [
                self.phoneme_cache_path, self.enable_eos_bos, self.cleaners,
                self.phoneme_language, self.tp, self.add_blank
                self.phoneme_cache_path,
                self.enable_eos_bos,
                self.cleaners,
                self.phoneme_language,
                self.tp,
                self.add_blank,
            ]
            if self.verbose:
                print(" | > Computing phonemes ...")

@ -220,10 +223,11 @@ class MyDataset(Dataset):
            else:
                with Pool(num_workers) as p:
                    phonemes = list(
                        tqdm.tqdm(p.imap(MyDataset._phoneme_worker,
                                         [[item, func_args]
                                          for item in self.items]),
                                  total=len(self.items)))
                        tqdm.tqdm(
                            p.imap(MyDataset._phoneme_worker, [[item, func_args] for item in self.items]),
                            total=len(self.items),
                        )
                    )
                    for idx, p in enumerate(phonemes):
                        self.items[idx][0] = p

@ -255,8 +259,10 @@ class MyDataset(Dataset):
            print(" | > Min length sequence: {}".format(np.min(lengths)))
            print(" | > Avg length sequence: {}".format(np.mean(lengths)))
            print(
                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}"
                .format(self.max_seq_len, self.min_seq_len, len(ignored)))
                " | > Num. instances discarded by max-min (max={}, min={}) seq limits: {}".format(
                    self.max_seq_len, self.min_seq_len, len(ignored)
                )
            )
            print(" | > Batch group size: {}.".format(self.batch_group_size))

    def __len__(self):

@ -267,11 +273,11 @@ class MyDataset(Dataset):

    def collate_fn(self, batch):
        r"""
            Perform preprocessing and create a final data batch:
            1. Sort batch instances by text-length
            2. Convert Audio signal to Spectrograms.
            3. PAD sequences wrt r.
            4. Load to Torch.
        Perform preprocessing and create a final data batch:
        1. Sort batch instances by text-length
        2. Convert Audio signal to Spectrograms.
        3. PAD sequences wrt r.
        4. Load to Torch.
        """

        # Puts each data field into a tensor with outer dimension batch size

@ -280,44 +286,29 @@ class MyDataset(Dataset):
            text_lenghts = np.array([len(d["text"]) for d in batch])

            # sort items with text input length for RNN efficiency
            text_lenghts, ids_sorted_decreasing = torch.sort(
                torch.LongTensor(text_lenghts), dim=0, descending=True)
            text_lenghts, ids_sorted_decreasing = torch.sort(torch.LongTensor(text_lenghts), dim=0, descending=True)

            wav = [batch[idx]['wav'] for idx in ids_sorted_decreasing]
            item_idxs = [
                batch[idx]['item_idx'] for idx in ids_sorted_decreasing
            ]
            text = [batch[idx]['text'] for idx in ids_sorted_decreasing]
            wav = [batch[idx]["wav"] for idx in ids_sorted_decreasing]
            item_idxs = [batch[idx]["item_idx"] for idx in ids_sorted_decreasing]
            text = [batch[idx]["text"] for idx in ids_sorted_decreasing]

            speaker_name = [
                batch[idx]['speaker_name'] for idx in ids_sorted_decreasing
            ]
            speaker_name = [batch[idx]["speaker_name"] for idx in ids_sorted_decreasing]
            # get speaker embeddings
            if self.speaker_mapping is not None:
                wav_files_names = [
                    batch[idx]['wav_file_name']
                    for idx in ids_sorted_decreasing
                ]
                speaker_embedding = [
                    self.speaker_mapping[w]['embedding']
                    for w in wav_files_names
                ]
                wav_files_names = [batch[idx]["wav_file_name"] for idx in ids_sorted_decreasing]
                speaker_embedding = [self.speaker_mapping[w]["embedding"] for w in wav_files_names]
            else:
                speaker_embedding = None
            # compute features
            mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
            mel = [self.ap.melspectrogram(w).astype("float32") for w in wav]

            mel_lengths = [m.shape[1] for m in mel]

            # compute 'stop token' targets
            stop_targets = [
                np.array([0.] * (mel_len - 1) + [1.])
                for mel_len in mel_lengths
            ]
            stop_targets = [np.array([0.0] * (mel_len - 1) + [1.0]) for mel_len in mel_lengths]

            # PAD stop targets
            stop_targets = prepare_stop_target(stop_targets,
                                               self.outputs_per_step)
            stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)

            # PAD sequences with longest instance in the batch
            text = prepare_data(text).astype(np.int32)

@ -340,9 +331,7 @@ class MyDataset(Dataset):

            # compute linear spectrogram
            if self.compute_linear_spec:
                linear = [
                    self.ap.spectrogram(w).astype('float32') for w in wav
                ]
                linear = [self.ap.spectrogram(w).astype("float32") for w in wav]
                linear = prepare_tensor(linear, self.outputs_per_step)
                linear = linear.transpose(0, 2, 1)
                assert mel.shape[1] == linear.shape[1]

@ -351,8 +340,8 @@ class MyDataset(Dataset):
                linear = None

            # collate attention alignments
            if batch[0]['attn'] is not None:
                attns = [batch[idx]['attn'].T for idx in ids_sorted_decreasing]
            if batch[0]["attn"] is not None:
                attns = [batch[idx]["attn"].T for idx in ids_sorted_decreasing]
                for idx, attn in enumerate(attns):
                    pad2 = mel.shape[1] - attn.shape[1]
                    pad1 = text.shape[1] - attn.shape[0]

@ -362,8 +351,24 @@ class MyDataset(Dataset):
                attns = torch.FloatTensor(attns).unsqueeze(1)
            else:
                attns = None
            return text, text_lenghts, speaker_name, linear, mel, mel_lengths, \
                stop_targets, item_idxs, speaker_embedding, attns
            return (
                text,
                text_lenghts,
                speaker_name,
                linear,
                mel,
                mel_lengths,
                stop_targets,
                item_idxs,
                speaker_embedding,
                attns,
            )

        raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
                         found {}".format(type(batch[0]))))
        raise TypeError(
            (
                "batch must contain tensors, numbers, dicts or lists;\
            found {}".format(
                    type(batch[0])
                )
            )
        )
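A sketch of wiring collate_fn into a PyTorch DataLoader; the dataset construction is elided, and the batch unpacking order follows the return tuple above:

from torch.utils.data import DataLoader

# dataset: a MyDataset instance built with the __init__ arguments shown above
loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn, num_workers=4)
(text, text_lenghts, speaker_name, linear, mel, mel_lengths,
 stop_targets, item_idxs, speaker_embedding, attns) = next(iter(loader))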

@ -13,14 +13,15 @@ from TTS.tts.utils.generic_utils import split_dataset
# UTILITIES
####################


def load_meta_data(datasets, eval_split=True):
    meta_data_train_all = []
    meta_data_eval_all = [] if eval_split else None
    for dataset in datasets:
        name = dataset['name']
        root_path = dataset['path']
        meta_file_train = dataset['meta_file_train']
        meta_file_val = dataset['meta_file_val']
        name = dataset["name"]
        root_path = dataset["path"]
        meta_file_train = dataset["meta_file_train"]
        meta_file_val = dataset["meta_file_val"]
        # setup the right data processor
        preprocessor = get_preprocessor_by_name(name)
        # load train set

@ -35,8 +36,8 @@ def load_meta_data(datasets, eval_split=True):
            meta_data_eval_all += meta_data_eval
        meta_data_train_all += meta_data_train
        # load attention masks for duration predictor training
        if 'meta_file_attn_mask' in dataset and dataset['meta_file_attn_mask'] is not None:
            meta_data = dict(load_attention_mask_meta_data(dataset['meta_file_attn_mask']))
        if "meta_file_attn_mask" in dataset and dataset["meta_file_attn_mask"] is not None:
            meta_data = dict(load_attention_mask_meta_data(dataset["meta_file_attn_mask"]))
            for idx, ins in enumerate(meta_data_train_all):
                attn_file = meta_data[ins[1]].strip()
                meta_data_train_all[idx].append(attn_file)

@ -49,12 +50,12 @@ def load_meta_data(datasets, eval_split=True):

def load_attention_mask_meta_data(metafile_path):
    """Load meta data file created by compute_attention_masks.py"""
    with open(metafile_path, 'r') as f:
    with open(metafile_path, "r") as f:
        lines = f.readlines()

    meta_data = []
    for line in lines:
        wav_file, attn_file = line.split('|')
        wav_file, attn_file = line.split("|")
        meta_data.append([wav_file, attn_file])
    return meta_data
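The metafile parsed above holds one wav_path|attn_path pair per line; a minimal round trip with hypothetical paths:

lines = ["/data/wavs/utt1.wav|/data/attn/utt1.npy\n"]
meta_data = [line.split("|") for line in lines]
assert meta_data[0][0] == "/data/wavs/utt1.wav"
assert meta_data[0][1].strip() == "/data/attn/utt1.npy"  # the caller strips the newline later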
@ -69,6 +70,7 @@ def get_preprocessor_by_name(name):
# DATASETS
########################


def tweb(root_path, meta_file):
    """Normalize TWEB dataset.
    https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset

@ -76,10 +78,10 @@ def tweb(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "tweb"
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split('\t')
            wav_file = os.path.join(root_path, cols[0] + '.wav')
            cols = line.split("\t")
            wav_file = os.path.join(root_path, cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items

@ -90,9 +92,9 @@ def mozilla(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split('|')
            cols = line.split("|")
            wav_file = cols[1].strip()
            text = cols[0].strip()
            wav_file = os.path.join(root_path, "wavs", wav_file)

@ -105,9 +107,9 @@ def mozilla_de(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "mozilla"
    with open(txt_file, 'r', encoding="ISO 8859-1") as ttf:
    with open(txt_file, "r", encoding="ISO 8859-1") as ttf:
        for line in ttf:
            cols = line.strip().split('|')
            cols = line.strip().split("|")
            wav_file = cols[0].strip()
            text = cols[1].strip()
            folder_name = f"BATCH_{wav_file.split('_')[0]}_FINAL"

@ -118,8 +120,7 @@ def mozilla_de(root_path, meta_file):

def mailabs(root_path, meta_files=None):
    """Normalizes M-AI-Labs meta data files to TTS format"""
    speaker_regex = re.compile(
        "by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    speaker_regex = re.compile("by_book/(male|female)/(?P<speaker_name>[^/]+)/")
    if meta_files is None:
        csv_files = glob(root_path + "/**/metadata.csv", recursive=True)
    else:

@ -135,21 +136,18 @@ def mailabs(root_path, meta_files=None):
            continue
        speaker_name = speaker_name_match.group("speaker_name")
        print(" | > {}".format(csv_file))
        with open(txt_file, 'r') as ttf:
        with open(txt_file, "r") as ttf:
            for line in ttf:
                cols = line.split('|')
                cols = line.split("|")
                if meta_files is None:
                    wav_file = os.path.join(folder, 'wavs', cols[0] + '.wav')
                    wav_file = os.path.join(folder, "wavs", cols[0] + ".wav")
                else:
                    wav_file = os.path.join(root_path,
                                            folder.replace("metadata.csv", ""),
                                            'wavs', cols[0] + '.wav')
                    wav_file = os.path.join(root_path, folder.replace("metadata.csv", ""), "wavs", cols[0] + ".wav")
                if os.path.isfile(wav_file):
                    text = cols[1].strip()
                    items.append([text, wav_file, speaker_name])
                else:
                    raise RuntimeError("> File %s does not exist!" %
                                       (wav_file))
                    raise RuntimeError("> File %s does not exist!" % (wav_file))
    return items


@ -159,10 +157,10 @@ def ljspeech(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, 'r', encoding="utf-8") as ttf:
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items
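For orientation, one illustrative LJSpeech metadata.csv line and what the formatter above turns it into (the path and truncated transcript are invented for the example):

import os

root_path = "/data/LJSpeech-1.1"
line = "LJ001-0001|Printing, in the only sense...|Printing, in the only sense..."
cols = line.split("|")
wav_file = os.path.join(root_path, "wavs", cols[0] + ".wav")
item = [cols[1], wav_file, "ljspeech"]  # [text, wav_file, speaker_name]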
@ -171,15 +169,15 @@ def ljspeech(root_path, meta_file):
def sam_accenture(root_path, meta_file):
    """Normalizes the sam-accenture meta data file to TTS format
    https://github.com/Sam-Accenture-Non-Binary-Voice/non-binary-voice-files"""
    xml_file = os.path.join(root_path, 'voice_over_recordings', meta_file)
    xml_file = os.path.join(root_path, "voice_over_recordings", meta_file)
    xml_root = ET.parse(xml_file).getroot()
    items = []
    speaker_name = "sam_accenture"
    for item in xml_root.findall('./fileid'):
    for item in xml_root.findall("./fileid"):
        text = item.text
        wav_file = os.path.join(root_path, 'vo_voice_quality_transformation', item.get('id')+'.wav')
        wav_file = os.path.join(root_path, "vo_voice_quality_transformation", item.get("id") + ".wav")
        if not os.path.exists(wav_file):
            print(f' [!] {wav_file} in metafile does not exist. Skipping...')
            print(f" [!] {wav_file} in metafile does not exist. Skipping...")
            continue
        items.append([text, wav_file, speaker_name])
    return items

@ -191,10 +189,10 @@ def ruslan(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, 'r', encoding="utf-8") as ttf:
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'RUSLAN', cols[0] + '.wav')
            cols = line.split("|")
            wav_file = os.path.join(root_path, "RUSLAN", cols[0] + ".wav")
            text = cols[1]
            items.append([text, wav_file, speaker_name])
    return items

@ -205,9 +203,9 @@ def css10(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "ljspeech"
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            cols = line.split('|')
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[1]
            items.append([text, wav_file, speaker_name])

@ -219,10 +217,10 @@ def nancy(root_path, meta_file):
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "nancy"
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            utt_id = line.split()[1]
            text = line[line.find('"') + 1:line.rfind('"') - 1]
            text = line[line.find('"') + 1 : line.rfind('"') - 1]
            wav_file = os.path.join(root_path, "wavn", utt_id + ".wav")
            items.append([text, wav_file, speaker_name])
    return items

@ -232,7 +230,7 @@ def common_voice(root_path, meta_file):
    """Normalize the common voice meta data file to TTS format."""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            if line.startswith("client_id"):
                continue

@ -240,7 +238,7 @@ def common_voice(root_path, meta_file):
            text = cols[2]
            speaker_name = cols[0]
            wav_file = os.path.join(root_path, "clips", cols[1].replace(".mp3", ".wav"))
            items.append([text, wav_file, 'MCV_' + speaker_name])
            items.append([text, wav_file, "MCV_" + speaker_name])
    return items


@ -250,19 +248,18 @@ def libri_tts(root_path, meta_files=None):
    if meta_files is None:
        meta_files = glob(f"{root_path}/**/*trans.tsv", recursive=True)
    for meta_file in meta_files:
        _meta_file = os.path.basename(meta_file).split('.')[0]
        speaker_name = _meta_file.split('_')[0]
        chapter_id = _meta_file.split('_')[1]
        _meta_file = os.path.basename(meta_file).split(".")[0]
        speaker_name = _meta_file.split("_")[0]
        chapter_id = _meta_file.split("_")[1]
        _root_path = os.path.join(root_path, f"{speaker_name}/{chapter_id}")
        with open(meta_file, 'r') as ttf:
        with open(meta_file, "r") as ttf:
            for line in ttf:
                cols = line.split('\t')
                wav_file = os.path.join(_root_path, cols[0] + '.wav')
                cols = line.split("\t")
                wav_file = os.path.join(_root_path, cols[0] + ".wav")
                text = cols[1]
                items.append([text, wav_file, 'LTTS_' + speaker_name])
                items.append([text, wav_file, "LTTS_" + speaker_name])
    for item in items:
        assert os.path.exists(
            item[1]), f" [!] wav files don't exist - {item[1]}"
        assert os.path.exists(item[1]), f" [!] wav files don't exist - {item[1]}"
    return items


@ -271,11 +268,10 @@ def custom_turkish(root_path, meta_file):
    items = []
    speaker_name = "turkish-female"
    skipped_files = []
    with open(txt_file, 'r', encoding='utf-8') as ttf:
    with open(txt_file, "r", encoding="utf-8") as ttf:
        for line in ttf:
            cols = line.split('|')
            wav_file = os.path.join(root_path, 'wavs',
                                    cols[0].strip() + '.wav')
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0].strip() + ".wav")
            if not os.path.exists(wav_file):
                skipped_files.append(wav_file)
                continue

@ -287,14 +283,14 @@ def custom_turkish(root_path, meta_file):

# ToDo: add the dataset link when the dataset is released publicly
def brspeech(root_path, meta_file):
    '''BRSpeech 3.0 beta'''
    """BRSpeech 3.0 beta"""
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            if line.startswith("wav_filename"):
                continue
            cols = line.split('|')
            cols = line.split("|")
            wav_file = os.path.join(root_path, cols[0])
            text = cols[2]
            speaker_name = cols[3]

@ -302,45 +298,41 @@ def brspeech(root_path, meta_file):
    return items


def vctk(root_path, meta_files=None, wavs_path='wav48'):
def vctk(root_path, meta_files=None, wavs_path="wav48"):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    test_speakers = meta_files
    items = []
    meta_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for meta_file in meta_files:
        _, speaker_id, txt_file = os.path.relpath(meta_file,
                                                  root_path).split(os.sep)
        file_id = txt_file.split('.')[0]
        if isinstance(test_speakers,
                      list):  # if is list ignore this speakers ids
        _, speaker_id, txt_file = os.path.relpath(meta_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(test_speakers, list):  # if is list ignore this speakers ids
            if speaker_id in test_speakers:
                continue
        with open(meta_file) as file_text:
            text = file_text.readlines()[0]
        wav_file = os.path.join(root_path, wavs_path, speaker_id,
                                file_id + '.wav')
        items.append([text, wav_file, 'VCTK_' + speaker_id])
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([text, wav_file, "VCTK_" + speaker_id])

    return items


def vctk_slim(root_path, meta_files=None, wavs_path='wav48'):
def vctk_slim(root_path, meta_files=None, wavs_path="wav48"):
    """homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz"""
    items = []
    txt_files = glob(f"{os.path.join(root_path,'txt')}/**/*.txt", recursive=True)
    for text_file in txt_files:
        _, speaker_id, txt_file = os.path.relpath(text_file,
                                                  root_path).split(os.sep)
        file_id = txt_file.split('.')[0]
        _, speaker_id, txt_file = os.path.relpath(text_file, root_path).split(os.sep)
        file_id = txt_file.split(".")[0]
        if isinstance(meta_files, list):  # if is list ignore this speakers ids
            if speaker_id in meta_files:
                continue
        wav_file = os.path.join(root_path, wavs_path, speaker_id,
                                file_id + '.wav')
        items.append([None, wav_file, 'VCTK_' + speaker_id])
        wav_file = os.path.join(root_path, wavs_path, speaker_id, file_id + ".wav")
        items.append([None, wav_file, "VCTK_" + speaker_id])

    return items


# ======================================== VOX CELEB ===========================================
def voxceleb2(root_path, meta_file=None):
    """

@ -365,31 +357,33 @@ def _voxcel_x(root_path, meta_file, voxcel_idx):

    # if not exists meta file, crawl recursively for 'wav' files
    if meta_file is not None:
        with open(str(meta_file), 'r') as f:
            return [x.strip().split('|') for x in f.readlines()]
        with open(str(meta_file), "r") as f:
            return [x.strip().split("|") for x in f.readlines()]

    elif not cache_to.exists():
        cnt = 0
        meta_data = []
        wav_files = voxceleb_path.rglob("**/*.wav")
        for path in tqdm(wav_files, desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.",
                         total=expected_count):
        for path in tqdm(
            wav_files,
            desc=f"Building VoxCeleb {voxcel_idx} Meta file ... this needs to be done only once.",
            total=expected_count,
        ):
            speaker_id = str(Path(path).parent.parent.stem)
            assert speaker_id.startswith('id')
            assert speaker_id.startswith("id")
            text = None  # VoxCel does not provide transcriptions, and they are not needed for training the SE
            meta_data.append(f"{text}|{path}|voxcel{voxcel_idx}_{speaker_id}\n")
            cnt += 1
        with open(str(cache_to), 'w') as f:
        with open(str(cache_to), "w") as f:
            f.write("".join(meta_data))
        if cnt < expected_count:
            raise ValueError(f"Found too few instances for Voxceleb. Should be around {expected_count}, is: {cnt}")

    with open(str(cache_to), 'r') as f:
        return [x.strip().split('|') for x in f.readlines()]
    with open(str(cache_to), "r") as f:
        return [x.strip().split("|") for x in f.readlines()]


def baker(root_path: str, meta_file: str) -> List[List[str]]:
    """Normalizes the Baker meta data file to TTS format

    Args:

@ -401,9 +395,9 @@ def baker(root_path: str, meta_file: str) -> List[List[str]]:
    txt_file = os.path.join(root_path, meta_file)
    items = []
    speaker_name = "baker"
    with open(txt_file, 'r') as ttf:
    with open(txt_file, "r") as ttf:
        for line in ttf:
            wav_name, text = line.rstrip('\n').split("|")
            wav_name, text = line.rstrip("\n").split("|")
            wav_path = os.path.join(root_path, "clips_22", wav_name)
            items.append([text, wav_path, speaker_name])
    return items

@ -5,6 +5,7 @@ class MDNBlock(nn.Module):
    """Mixture of Density Network implementation
    https://arxiv.org/pdf/2003.01950.pdf
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.out_channels = out_channels

@ -24,6 +25,6 @@ class MDNBlock(nn.Module):
        mu_sigma = self.conv2(o)
        # TODO: check this sigmoid
        # mu = torch.sigmoid(mu_sigma[:, :self.out_channels//2, :])
        mu = mu_sigma[:, :self.out_channels//2, :]
        log_sigma = mu_sigma[:, self.out_channels//2:, :]
        mu = mu_sigma[:, : self.out_channels // 2, :]
        log_sigma = mu_sigma[:, self.out_channels // 2 :, :]
        return mu, log_sigma
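The channel split above halves conv2's output into means and log standard deviations; a minimal standalone check of the slicing:

import torch

out_channels = 160                               # 2 * 80 output dims, say
mu_sigma = torch.randn(8, out_channels, 50)      # (batch, 2 * out dims, time)
mu = mu_sigma[:, : out_channels // 2, :]         # first half: means
log_sigma = mu_sigma[:, out_channels // 2 :, :]  # second half: log sigmas
assert mu.shape == log_sigma.shape == (8, 80, 50)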
|
|
|
@ -31,15 +31,16 @@ class WaveNetDecoder(nn.Module):
|
|||
hidden_channels (int): number of hidden channels for prenet and postnet.
|
||||
params (dict): dictionary for residual convolutional blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
|
||||
super().__init__()
|
||||
# prenet
|
||||
self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
|
||||
self.prenet = torch.nn.Conv1d(in_channels, params["hidden_channels"], 1)
|
||||
# wavenet layers
|
||||
self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
|
||||
self.wn = WNBlocks(params["hidden_channels"], c_in_channels=c_in_channels, **params)
|
||||
# postnet
|
||||
self.postnet = [
|
||||
torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
|
||||
torch.nn.Conv1d(params["hidden_channels"], hidden_channels, 1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
|
||||
torch.nn.ReLU(),
|
||||
|
@ -77,12 +78,12 @@ class RelativePositionTransformerDecoder(nn.Module):
|
|||
hidden_channels (int): number of hidden channels including Transformer layers.
|
||||
params (dict): dictionary for residual convolutional blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||
|
||||
super().__init__()
|
||||
self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
|
||||
self.rel_pos_transformer = RelativePositionTransformer(
|
||||
in_channels, out_channels, hidden_channels, **params)
|
||||
self.rel_pos_transformer = RelativePositionTransformer(in_channels, out_channels, hidden_channels, **params)
|
||||
|
||||
def forward(self, x, x_mask=None, g=None): # pylint: disable=unused-argument
|
||||
o = self.prenet(x) * x_mask
|
||||
|
@ -107,6 +108,7 @@ class FFTransformerDecoder(nn.Module):
|
|||
hidden_channels (int): number of hidden channels including Transformer layers.
|
||||
params (dict): dictionary for residual convolutional blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, params):
|
||||
|
||||
super().__init__()
|
||||
|
@ -117,7 +119,7 @@ class FFTransformerDecoder(nn.Module):
|
|||
# TODO: handle multi-speaker
|
||||
x_mask = 1 if x_mask is None else x_mask
|
||||
o = self.transformer_block(x) * x_mask
|
||||
o = self.postnet(o)* x_mask
|
||||
o = self.postnet(o) * x_mask
|
||||
return o
|
||||
|
||||
|
||||
|
@ -141,19 +143,15 @@ class ResidualConv1dBNDecoder(nn.Module):
|
|||
hidden_channels (int): number of hidden channels including ResidualConv1dBNBlock layers.
|
||||
params (dict): dictionary for residual convolutional blocks.
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, out_channels, hidden_channels, params):
|
||||
super().__init__()
|
||||
self.res_conv_block = ResidualConv1dBNBlock(in_channels,
|
||||
hidden_channels,
|
||||
hidden_channels, **params)
|
||||
self.res_conv_block = ResidualConv1dBNBlock(in_channels, hidden_channels, hidden_channels, **params)
|
||||
self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
|
||||
self.postnet = nn.Sequential(
|
||||
Conv1dBNBlock(hidden_channels,
|
||||
hidden_channels,
|
||||
hidden_channels,
|
||||
params['kernel_size'],
|
||||
1,
|
||||
num_conv_blocks=2),
|
||||
Conv1dBNBlock(
|
||||
hidden_channels, hidden_channels, hidden_channels, params["kernel_size"], 1, num_conv_blocks=2
|
||||
),
|
||||
nn.Conv1d(hidden_channels, out_channels, 1),
|
||||
)
|
||||
|
||||
|
@ -178,17 +176,18 @@ class Decoder(nn.Module):
|
|||
|
||||
# pylint: disable=dangerous-default-value
|
||||
def __init__(
|
||||
self,
|
||||
out_channels,
|
||||
in_hidden_channels,
|
||||
decoder_type='residual_conv_bn',
|
||||
decoder_params={
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 17
|
||||
},
|
||||
c_in_channels=0):
|
||||
self,
|
||||
out_channels,
|
||||
in_hidden_channels,
|
||||
decoder_type="residual_conv_bn",
|
||||
decoder_params={
|
||||
"kernel_size": 4,
|
||||
"dilations": 4 * [1, 2, 4, 8] + [1],
|
||||
"num_conv_blocks": 2,
|
||||
"num_res_blocks": 17,
|
||||
},
|
||||
c_in_channels=0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if decoder_type.lower() == "relative_position_transformer":
|
||||
|
@ -196,23 +195,27 @@ class Decoder(nn.Module):
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type.lower() == 'residual_conv_bn':
params=decoder_params,
)
elif decoder_type.lower() == "residual_conv_bn":
self.decoder = ResidualConv1dBNDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
params=decoder_params)
elif decoder_type.lower() == 'wavenet':
self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels,
params=decoder_params)
elif decoder_type.lower() == 'fftransformer':
params=decoder_params,
)
elif decoder_type.lower() == "wavenet":
self.decoder = WaveNetDecoder(
in_channels=in_hidden_channels,
out_channels=out_channels,
hidden_channels=in_hidden_channels,
c_in_channels=c_in_channels,
params=decoder_params,
)
elif decoder_type.lower() == "fftransformer":
self.decoder = FFTransformerDecoder(in_hidden_channels, out_channels, decoder_params)
else:
raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
raise ValueError(f"[!] Unknown decoder type - {decoder_type}")

def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
"""
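For orientation, a minimal usage sketch of the decoder dispatch above. The import path and the 80-band spectrogram size are assumptions for illustration; they are not taken from this diff.

import torch
from TTS.tts.layers.feed_forward.decoder import Decoder  # assumed module path

decoder = Decoder(out_channels=80, in_hidden_channels=128)  # default "residual_conv_bn" branch
x = torch.randn(2, 128, 50)    # (batch, hidden channels, time)
x_mask = torch.ones(2, 1, 50)  # all frames valid
o = decoder(x, x_mask)         # (2, 80, 50) after the out_channels projection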
@ -16,16 +16,19 @@ class DurationPredictor(nn.Module):
Args:
hidden_channels (int): number of channels in the inner layers.
"""

def __init__(self, hidden_channels):

super().__init__()

self.layers = nn.ModuleList([
Conv1dBN(hidden_channels, hidden_channels, 4, 1),
Conv1dBN(hidden_channels, hidden_channels, 3, 1),
Conv1dBN(hidden_channels, hidden_channels, 1, 1),
nn.Conv1d(hidden_channels, 1, 1)
])
self.layers = nn.ModuleList(
[
Conv1dBN(hidden_channels, hidden_channels, 4, 1),
Conv1dBN(hidden_channels, hidden_channels, 3, 1),
Conv1dBN(hidden_channels, hidden_channels, 1, 1),
nn.Conv1d(hidden_channels, 1, 1),
]
)

def forward(self, x, x_mask):
"""
@ -1,7 +1,7 @@
from torch import nn

from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
from TTS.tts.layers.generic.transformer import FFTransformerBlock

@ -16,17 +16,19 @@ class RelativePositionTransformerEncoder(nn.Module):
hidden_channels (int): number of hidden channels
params (dict): dictionary for residual convolutional blocks.
"""

def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = ResidualConv1dBNBlock(in_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_res_blocks=3,
num_conv_blocks=1,
dilations=[1, 1, 1])
self.rel_pos_transformer = RelativePositionTransformer(
hidden_channels, out_channels, hidden_channels, **params)
self.prenet = ResidualConv1dBNBlock(
in_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_res_blocks=3,
num_conv_blocks=1,
dilations=[1, 1, 1],
)
self.rel_pos_transformer = RelativePositionTransformer(hidden_channels, out_channels, hidden_channels, **params)

def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
if x_mask is None:
@ -47,20 +49,20 @@ class ResidualConv1dBNEncoder(nn.Module):
hidden_channels (int): number of hidden channels
params (dict): dictionary for residual convolutional blocks.
"""

def __init__(self, in_channels, out_channels, hidden_channels, params):
super().__init__()
self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1),
nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels, **params)
self.prenet = nn.Sequential(nn.Conv1d(in_channels, hidden_channels, 1), nn.ReLU())
self.res_conv_block = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **params)

self.postnet = nn.Sequential(*[
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU(),
nn.BatchNorm1d(hidden_channels),
nn.Conv1d(hidden_channels, out_channels, 1)
])
self.postnet = nn.Sequential(
*[
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU(),
nn.BatchNorm1d(hidden_channels),
nn.Conv1d(hidden_channels, out_channels, 1),
]
)

def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
if x_mask is None:
@ -115,18 +117,15 @@ class Encoder(nn.Module):
}
```
"""

def __init__(
self,
in_hidden_channels,
out_channels,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
c_in_channels=0):
self,
in_hidden_channels,
out_channels,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
c_in_channels=0,
):
super().__init__()
self.out_channels = out_channels
self.in_channels = in_hidden_channels
@ -137,21 +136,22 @@ class Encoder(nn.Module):
# init encoder
if encoder_type.lower() == "relative_position_transformer":
# text encoder
# pylint: disable=unexpected-keyword-arg
self.encoder = RelativePositionTransformerEncoder(
in_hidden_channels, out_channels, in_hidden_channels,
encoder_params)  # pylint: disable=unexpected-keyword-arg
elif encoder_type.lower() == 'residual_conv_bn':
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
out_channels,
in_hidden_channels,
encoder_params)
elif encoder_type.lower() == 'fftransformer':
assert in_hidden_channels == out_channels, \
"[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
self.encoder = FFTransformerBlock(in_hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
in_hidden_channels, out_channels, in_hidden_channels, encoder_params
)
elif encoder_type.lower() == "residual_conv_bn":
self.encoder = ResidualConv1dBNEncoder(in_hidden_channels, out_channels, in_hidden_channels, encoder_params)
elif encoder_type.lower() == "fftransformer":
assert (
in_hidden_channels == out_channels
), "[!] must be `in_channels` == `out_channels` when encoder type is 'fftransformer'"
# pylint: disable=unexpected-keyword-arg
self.encoder = FFTransformerBlock(
in_hidden_channels, **encoder_params
)
else:
raise NotImplementedError(' [!] unknown encoder type.')

raise NotImplementedError(" [!] unknown encoder type.")

def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
"""
@ -10,6 +10,7 @@ class GatedConvBlock(nn.Module):
kernel_size (int): convolution kernel size.
dropout_p (float): dropout rate.
"""

def __init__(self, in_out_channels, kernel_size, dropout_p, num_layers):
super().__init__()
# class arguments

@ -20,21 +21,14 @@ class GatedConvBlock(nn.Module):
self.norm_layers = nn.ModuleList()
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.conv_layers += [
nn.Conv1d(in_out_channels,
2 * in_out_channels,
kernel_size,
padding=kernel_size // 2)
]
self.conv_layers += [nn.Conv1d(in_out_channels, 2 * in_out_channels, kernel_size, padding=kernel_size // 2)]
self.norm_layers += [LayerNorm(2 * in_out_channels)]

def forward(self, x, x_mask):
o = x
res = x
for idx in range(self.num_layers):
o = nn.functional.dropout(o,
p=self.dropout_p,
training=self.training)
o = nn.functional.dropout(o, p=self.dropout_p, training=self.training)
o = self.conv_layers[idx](o * x_mask)
o = self.norm_layers[idx](o)
o = nn.functional.glu(o, dim=1)
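The loop above ends with nn.functional.glu, which halves the doubled channel dimension produced by each conv layer. A standalone sketch of that gating step, with shapes chosen arbitrarily:

import torch
import torch.nn.functional as F

o = torch.randn(2, 2 * 64, 100)       # conv output with doubled channels: (B, 2C, T)
value, gate = o.chunk(2, dim=1)       # glu splits along dim=1 ...
manual = value * torch.sigmoid(gate)  # ... and gates the first half with the second
assert torch.allclose(F.glu(o, dim=1), manual)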
@ -22,24 +22,17 @@ class LayerNorm(nn.Module):

def forward(self, x):
mean = torch.mean(x, 1, keepdim=True)
variance = torch.mean((x - mean)**2, 1, keepdim=True)
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + self.eps)
x = x * self.gamma + self.beta
return x


class TemporalBatchNorm1d(nn.BatchNorm1d):
"""Normalize each channel separately over time and batch.
"""
def __init__(self,
channels,
affine=True,
track_running_stats=True,
momentum=0.1):
super().__init__(channels,
affine=affine,
track_running_stats=track_running_stats,
momentum=momentum)
"""Normalize each channel separately over time and batch."""

def __init__(self, channels, affine=True, track_running_stats=True, momentum=0.1):
super().__init__(channels, affine=affine, track_running_stats=track_running_stats, momentum=momentum)

def forward(self, x):
return super().forward(x.transpose(2, 1)).transpose(2, 1)
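This LayerNorm normalizes over the channel axis (dim 1) of a (B, C, T) tensor rather than the last axis. A quick sketch of the same statistics without the learnable gamma/beta (the eps value is an assumption):

import torch

x = torch.randn(4, 16, 32)                    # (B, C, T)
mean = x.mean(1, keepdim=True)                # one mean per (batch, time) position
var = ((x - mean) ** 2).mean(1, keepdim=True)
y = (x - mean) * torch.rsqrt(var + 1e-4)
assert torch.allclose(y.mean(1), torch.zeros(4, 32), atol=1e-5)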
@ -58,6 +51,7 @@ class ActNorm(nn.Module):
- inputs: (B, C, T)
- outputs: (B, C, T)
"""

def __init__(self, channels, ddi=False, **kwargs):  # pylint: disable=unused-argument
super().__init__()
self.channels = channels

@ -68,8 +62,7 @@ class ActNorm(nn.Module):

def forward(self, x, x_mask=None, reverse=False, **kwargs):  # pylint: disable=unused-argument
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)

@ -95,13 +88,11 @@ class ActNorm(nn.Module):
denom = torch.sum(x_mask, [0, 2])
m = torch.sum(x * x_mask, [0, 2]) / denom
m_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
v = m_sq - (m**2)
v = m_sq - (m ** 2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))

bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(
dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(
dtype=self.logs.dtype)
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

self.bias.data.copy_(bias_init)
self.logs.data.copy_(logs_init)
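The initialization above picks bias and logs so that the first forward pass is roughly zero-mean, unit-variance per channel. A sketch of the identity being exploited, assuming the usual ActNorm affine form z = bias + exp(logs) * x (masking omitted):

import torch

x = torch.randn(8, 4, 100) * 3.0 + 1.5          # arbitrary per-channel statistics
m = x.mean(dim=(0, 2))                          # per-channel mean, as in initialize()
v = (x * x).mean(dim=(0, 2)) - m ** 2           # per-channel variance
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias, log_scale = -m * torch.exp(-logs), -logs  # the stored parameters
z = bias.view(1, -1, 1) + torch.exp(log_scale).view(1, -1, 1) * x
# z equals (x - m) / sqrt(v): standardized per channel after initialization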
@ -11,20 +11,20 @@ class PositionalEncoding(nn.Module):
channels (int): embedding size
dropout (float): dropout parameter
"""

def __init__(self, channels, dropout_p=0.0, max_len=5000):
super().__init__()
if channels % 2 != 0:
raise ValueError(
"Cannot use sin/cos positional encoding with "
"odd channels (got channels={:d})".format(channels))
"Cannot use sin/cos positional encoding with " "odd channels (got channels={:d})".format(channels)
)
pe = torch.zeros(max_len, channels)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.pow(10000,
torch.arange(0, channels, 2).float() / channels)
div_term = torch.pow(10000, torch.arange(0, channels, 2).float() / channels)
pe[:, 0::2] = torch.sin(position.float() * div_term)
pe[:, 1::2] = torch.cos(position.float() * div_term)
pe = pe.unsqueeze(0).transpose(1, 2)
self.register_buffer('pe', pe)
self.register_buffer("pe", pe)
if dropout_p > 0:
self.dropout = nn.Dropout(p=dropout_p)
self.channels = channels

@ -43,14 +43,15 @@ class PositionalEncoding(nn.Module):
if self.pe.size(2) < x.size(2):
raise RuntimeError(
f"Sequence is {x.size(2)} but PositionalEncoding is"
f" limited to {self.pe.size(2)}. See max_len argument.")
f" limited to {self.pe.size(2)}. See max_len argument."
)
if mask is not None:
pos_enc = (self.pe[:, :, :x.size(2)] * mask)
pos_enc = self.pe[:, :, : x.size(2)] * mask
else:
pos_enc = self.pe[:, :, :x.size(2)]
pos_enc = self.pe[:, :, : x.size(2)]
x = x + pos_enc
else:
x = x + self.pe[:, :, first_idx:last_idx]
if hasattr(self, 'dropout'):
if hasattr(self, "dropout"):
x = self.dropout(x)
return x
@ -3,9 +3,10 @@ from torch import nn


class ZeroTemporalPad(nn.Module):
"""Pad sequences to equal length in the temporal dimension"""

def __init__(self, kernel_size, dilation):
super().__init__()
total_pad = (dilation * (kernel_size - 1))
total_pad = dilation * (kernel_size - 1)
begin = total_pad // 2
end = total_pad - begin
self.pad_layer = nn.ZeroPad2d((0, 0, begin, end))

@ -27,9 +28,10 @@ class Conv1dBN(nn.Module):
kernel_size (int): kernel size for convolutional filters.
dilation (int): dilation for convolution layers.
"""

def __init__(self, in_channels, out_channels, kernel_size, dilation):
super().__init__()
padding = (dilation * (kernel_size - 1))
padding = dilation * (kernel_size - 1)
pad_s = padding // 2
pad_e = padding - pad_s
self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)

@ -55,14 +57,17 @@ class Conv1dBNBlock(nn.Module):
dilation (int): dilation for convolution layers.
num_conv_blocks (int, optional): number of convolutional blocks. Defaults to 2.
"""

def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation, num_conv_blocks=2):
super().__init__()
self.conv_bn_blocks = []
for idx in range(num_conv_blocks):
layer = Conv1dBN(in_channels if idx == 0 else hidden_channels,
out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
kernel_size,
dilation)
layer = Conv1dBN(
in_channels if idx == 0 else hidden_channels,
out_channels if idx == (num_conv_blocks - 1) else hidden_channels,
kernel_size,
dilation,
)
self.conv_bn_blocks.append(layer)
self.conv_bn_blocks = nn.Sequential(*self.conv_bn_blocks)
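The padding arithmetic above keeps the temporal length unchanged for any kernel/dilation pair, including even kernels via the asymmetric begin/end split. A sketch checking the output length:

import torch
import torch.nn.functional as F
from torch import nn

kernel_size, dilation, T = 4, 8, 100
padding = dilation * (kernel_size - 1)           # receptive-field shrinkage to undo
pad_s, pad_e = padding // 2, padding - padding // 2
x = F.pad(torch.randn(1, 8, T), (pad_s, pad_e))  # asymmetric pad on the time axis
y = nn.Conv1d(8, 8, kernel_size, dilation=dilation)(x)
assert y.size(2) == T                            # length preserved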
@ -91,18 +96,23 @@ class ResidualConv1dBNBlock(nn.Module):
num_res_blocks (int, optional): number of residual blocks. Defaults to 13.
num_conv_blocks (int, optional): number of convolutional blocks in each residual block. Defaults to 2.
"""
def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2):

def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilations, num_res_blocks=13, num_conv_blocks=2
):

super().__init__()
assert len(dilations) == num_res_blocks
self.res_blocks = nn.ModuleList()
for idx, dilation in enumerate(dilations):
block = Conv1dBNBlock(in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == len(dilations) else hidden_channels,
hidden_channels,
kernel_size,
dilation,
num_conv_blocks)
block = Conv1dBNBlock(
in_channels if idx == 0 else hidden_channels,
out_channels if (idx + 1) == len(dilations) else hidden_channels,
hidden_channels,
kernel_size,
dilation,
num_conv_blocks,
)
self.res_blocks.append(block)

def forward(self, x, x_mask=None):
@ -5,12 +5,8 @@ from torch import nn
class TimeDepthSeparableConv(nn.Module):
"""Time depth separable convolution as in https://arxiv.org/pdf/1904.02619.pdf
It shows competitive results with less computation and memory footprint."""
def __init__(self,
in_channels,
hid_channels,
out_channels,
kernel_size,
bias=True):

def __init__(self, in_channels, hid_channels, out_channels, kernel_size, bias=True):
super().__init__()

self.in_channels = in_channels

@ -62,28 +58,24 @@ class TimeDepthSeparableConv(nn.Module):


class TimeDepthSeparableConvBlock(nn.Module):
def __init__(self,
in_channels,
hid_channels,
out_channels,
num_layers,
kernel_size,
bias=True):
def __init__(self, in_channels, hid_channels, out_channels, num_layers, kernel_size, bias=True):
super().__init__()
assert (kernel_size - 1) % 2 == 0
assert num_layers > 1

self.layers = nn.ModuleList()
layer = TimeDepthSeparableConv(
in_channels, hid_channels,
out_channels if num_layers == 1 else hid_channels, kernel_size,
bias)
in_channels, hid_channels, out_channels if num_layers == 1 else hid_channels, kernel_size, bias
)
self.layers.append(layer)
for idx in range(num_layers - 1):
layer = TimeDepthSeparableConv(
hid_channels, hid_channels, out_channels if
(idx + 1) == (num_layers - 1) else hid_channels, kernel_size,
bias)
hid_channels,
hid_channels,
out_channels if (idx + 1) == (num_layers - 1) else hid_channels,
kernel_size,
bias,
)
self.layers.append(layer)

def forward(self, x, mask):
@ -4,16 +4,9 @@ import torch.nn.functional as F


class FFTransformer(nn.Module):
def __init__(self,
in_out_channels,
num_heads,
hidden_channels_ffn=1024,
kernel_size_fft=3,
dropout_p=0.1):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn=1024, kernel_size_fft=3, dropout_p=0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(in_out_channels,
num_heads,
dropout=dropout_p)
self.self_attn = nn.MultiheadAttention(in_out_channels, num_heads, dropout=dropout_p)

padding = (kernel_size_fft - 1) // 2
self.conv1 = nn.Conv1d(in_out_channels, hidden_channels_ffn, kernel_size=kernel_size_fft, padding=padding)

@ -27,11 +20,7 @@ class FFTransformer(nn.Module):
def forward(self, src, src_mask=None, src_key_padding_mask=None):
"""😦 ugly looking with all the transposing """
src = src.permute(2, 0, 1)
src2, enc_align = self.self_attn(src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)
src2, enc_align = self.self_attn(src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)
src = self.norm1(src + src2)
# T x B x D -> B x D x T
src = src.permute(1, 2, 0)

@ -45,15 +34,19 @@ class FFTransformer(nn.Module):


class FFTransformerBlock(nn.Module):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn,
num_layers, dropout_p):
def __init__(self, in_out_channels, num_heads, hidden_channels_ffn, num_layers, dropout_p):
super().__init__()
self.fft_layers = nn.ModuleList([
FFTransformer(in_out_channels=in_out_channels,
num_heads=num_heads,
hidden_channels_ffn=hidden_channels_ffn,
dropout_p=dropout_p) for _ in range(num_layers)
])
self.fft_layers = nn.ModuleList(
[
FFTransformer(
in_out_channels=in_out_channels,
num_heads=num_heads,
hidden_channels_ffn=hidden_channels_ffn,
dropout_p=dropout_p,
)
for _ in range(num_layers)
]
)

def forward(self, x, mask=None, g=None):  # pylint: disable=unused-argument
"""
@ -32,15 +32,18 @@ class WN(torch.nn.Module):
dropout_p (float): dropout rate.
weight_norm (bool): enable/disable weight norm for convolution layers.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True):

def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True,
):
super().__init__()
assert kernel_size % 2 == 1
assert hidden_channels % 2 == 0
@ -58,20 +61,16 @@ class WN(torch.nn.Module):

# init conditioning layer
if c_in_channels > 0:
cond_layer = torch.nn.Conv1d(c_in_channels,
2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer,
name='weight')
cond_layer = torch.nn.Conv1d(c_in_channels, 2 * hidden_channels * num_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
# intermediate layers
for i in range(num_layers):
dilation = dilation_rate**i
dilation = dilation_rate ** i
padding = int((kernel_size * dilation - dilation) / 2)
in_layer = torch.nn.Conv1d(hidden_channels,
2 * hidden_channels,
kernel_size,
dilation=dilation,
padding=padding)
in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
in_layer = torch.nn.Conv1d(
hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
)
in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
self.in_layers.append(in_layer)

if i < num_layers - 1:
@ -79,10 +78,8 @@ class WN(torch.nn.Module):
else:
res_skip_channels = hidden_channels

res_skip_layer = torch.nn.Conv1d(hidden_channels,
res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer,
name='weight')
res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
self.res_skip_layers.append(res_skip_layer)
# setup weight norm
if not weight_norm:

@ -99,15 +96,14 @@ class WN(torch.nn.Module):
x_in = self.dropout(x_in)
if g is not None:
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset:cond_offset + 2 * self.hidden_channels, :]
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
else:
g_l = torch.zeros_like(x_in)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l,
n_channels_tensor)
acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
res_skip_acts = self.res_skip_layers[i](acts)
if i < self.num_layers - 1:
x = (x + res_skip_acts[:, :self.hidden_channels, :]) * x_mask
output = output + res_skip_acts[:, self.hidden_channels:, :]
x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
output = output + res_skip_acts[:, self.hidden_channels :, :]
else:
output = output + res_skip_acts
return output * x_mask
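fused_add_tanh_sigmoid_multiply is the standard WaveNet gate; a sketch of the unfused equivalent, assuming the usual split of the doubled channels into tanh and sigmoid halves:

import torch

hidden = 64
x_in = torch.randn(2, 2 * hidden, 50)  # dilated-conv output, 2x hidden channels
g_l = torch.zeros_like(x_in)           # conditioning slice (zeros when g is None)
s = x_in + g_l
acts = torch.tanh(s[:, :hidden]) * torch.sigmoid(s[:, hidden:])
# acts has `hidden` channels and feeds the 1x1 res/skip projection above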
@ -140,28 +136,32 @@ class WNBlocks(nn.Module):
weight_norm (bool): enable/disable weight norm for convolution layers.
"""

def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_blocks,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_blocks,
num_layers,
c_in_channels=0,
dropout_p=0,
weight_norm=True,
):

super().__init__()
self.wn_blocks = nn.ModuleList()
for idx in range(num_blocks):
layer = WN(in_channels=in_channels if idx == 0 else hidden_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
weight_norm=weight_norm)
layer = WN(
in_channels=in_channels if idx == 0 else hidden_channels,
hidden_channels=hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
weight_norm=weight_norm,
)
self.wn_blocks.append(layer)

def forward(self, x, x_mask=None, g=None):
@ -18,14 +18,12 @@ def squeeze(x, x_mask=None, num_sqz=2):
t = (t // num_sqz) * num_sqz
x = x[:, :, :t]
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1,
2).contiguous().view(b, c * num_sqz, t // num_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)

if x_mask is not None:
x_mask = x_mask[:, :, num_sqz - 1::num_sqz]
x_mask = x_mask[:, :, num_sqz - 1 :: num_sqz]
else:
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(b, 1, t // num_sqz).to(device=x.device, dtype=x.dtype)
return x_sqz * x_mask, x_mask


@ -34,20 +32,16 @@ def unsqueeze(x, x_mask=None, num_sqz=2):

Note:
each 's' is an n-dimensional vector.
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]] """
[[s1, s3, s5], [s2, s4, s6]] --> [[s1, s3, s5], [s2, s4, s6]]"""
b, c, t = x.size()

x_unsqz = x.view(b, num_sqz, c // num_sqz, t)
x_unsqz = x_unsqz.permute(0, 2, 3,
1).contiguous().view(b, c // num_sqz,
t * num_sqz)
x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // num_sqz, t * num_sqz)

if x_mask is not None:
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1,
num_sqz).view(b, 1, t * num_sqz)
x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, num_sqz).view(b, 1, t * num_sqz)
else:
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device,
dtype=x.dtype)
x_mask = torch.ones(b, 1, t * num_sqz).to(device=x.device, dtype=x.dtype)
return x_unsqz * x_mask, x_mask
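squeeze trades time resolution for channels and unsqueeze undoes it exactly. A round-trip sketch using the same reshapes (even sequence length, so no frames are cropped):

import torch

b, c, t, num_sqz = 2, 4, 10, 2
x = torch.randn(b, c, t)
x_sqz = x.view(b, c, t // num_sqz, num_sqz)
x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * num_sqz, t // num_sqz)
x_rec = x_sqz.view(b, num_sqz, c, t // num_sqz)
x_rec = x_rec.permute(0, 2, 3, 1).contiguous().view(b, c, t)
assert torch.equal(x_rec, x)  # unsqueeze inverts squeeze exactly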
@ -65,18 +59,21 @@ class Decoder(nn.Module):
dropout_p (float): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling in coupling layer.
"""
def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
c_in_channels=0):

def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_flow_blocks,
num_coupling_layers,
dropout_p=0.0,
num_splits=4,
num_squeeze=2,
sigmoid_scale=False,
c_in_channels=0,
):
super().__init__()

self.in_channels = in_channels
@ -94,18 +91,19 @@ class Decoder(nn.Module):
self.flows = nn.ModuleList()
for _ in range(num_flow_blocks):
self.flows.append(ActNorm(channels=in_channels * num_squeeze))
self.flows.append(InvConvNear(channels=in_channels * num_squeeze, num_splits=num_splits))
self.flows.append(
InvConvNear(channels=in_channels * num_squeeze,
num_splits=num_splits))
self.flows.append(
CouplingBlock(in_channels * num_squeeze,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale))
CouplingBlock(
in_channels * num_squeeze,
hidden_channels,
kernel_size=kernel_size,
dilation_rate=dilation_rate,
num_layers=num_coupling_layers,
c_in_channels=c_in_channels,
dropout_p=dropout_p,
sigmoid_scale=sigmoid_scale,
)
)

def forward(self, x, x_mask, g=None, reverse=False):
if not reverse:
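Each flow block therefore contributes the triple ActNorm -> InvConvNear -> CouplingBlock, and reversing the decoder just walks the list backwards. A sketch of that pattern, assuming each flow returns (z, logdet) as the modules above do:

def run_flows(flows, x, x_mask, g=None, reverse=False):
    # Illustrative only; the real list is built in the loop above.
    logdet_total = 0
    for flow in (reversed(flows) if reverse else flows):
        x, logdet = flow(x, x_mask, g=g, reverse=reverse)
        if not reverse and logdet is not None:
            logdet_total = logdet_total + logdet  # sum of per-flow Jacobian terms
    return x, logdet_total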
@ -14,6 +14,7 @@ class DurationPredictor(nn.Module):
kernel_size ([type]): [description]
dropout_p ([type]): [description]
"""

def __init__(self, in_channels, hidden_channels, kernel_size, dropout_p):
super().__init__()
# class arguments

@ -23,15 +24,9 @@ class DurationPredictor(nn.Module):
self.dropout_p = dropout_p
# layers
self.drop = nn.Dropout(dropout_p)
self.conv_1 = nn.Conv1d(in_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = LayerNorm(hidden_channels)
self.conv_2 = nn.Conv1d(hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = LayerNorm(hidden_channels)
# output layer
self.proj = nn.Conv1d(hidden_channels, 1, 1)
@ -69,17 +69,20 @@ class Encoder(nn.Module):
'num_layers': 9,
}
"""
def __init__(self,
num_chars,
out_channels,
hidden_channels,
hidden_channels_dp,
encoder_type,
encoder_params,
dropout_p_dp=0.1,
mean_only=False,
use_prenet=True,
c_in_channels=0):

def __init__(
self,
num_chars,
out_channels,
hidden_channels,
hidden_channels_dp,
encoder_type,
encoder_params,
dropout_p_dp=0.1,
mean_only=False,
use_prenet=True,
c_in_channels=0,
):
super().__init__()
# class arguments
self.num_chars = num_chars
@ -93,47 +96,33 @@ class Encoder(nn.Module):
self.encoder_type = encoder_type
# embedding layer
self.emb = nn.Embedding(num_chars, hidden_channels)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)
# init encoder module
if encoder_type.lower() == "rel_pos_transformer":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
self.encoder = RelativePositionTransformer(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
elif encoder_type.lower() == 'gated_conv':
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == 'residual_conv_bn':
if use_prenet:
self.prenet = nn.Sequential(
nn.Conv1d(hidden_channels, hidden_channels, 1),
nn.ReLU()
self.prenet = ResidualConv1dLayerNormBlock(
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
)
self.encoder = ResidualConv1dBNBlock(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
self.postnet = nn.Sequential(
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1),
nn.BatchNorm1d(self.hidden_channels))
elif encoder_type.lower() == 'time_depth_separable':
self.encoder = RelativePositionTransformer(
hidden_channels, hidden_channels, hidden_channels, **encoder_params
)
elif encoder_type.lower() == "gated_conv":
self.encoder = GatedConvBlock(hidden_channels, **encoder_params)
elif encoder_type.lower() == "residual_conv_bn":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(hidden_channels,
hidden_channels,
hidden_channels,
kernel_size=5,
num_layers=3,
dropout_p=0.5)
self.encoder = TimeDepthSeparableConvBlock(hidden_channels,
hidden_channels,
hidden_channels,
**encoder_params)
self.prenet = nn.Sequential(nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
self.encoder = ResidualConv1dBNBlock(hidden_channels, hidden_channels, hidden_channels, **encoder_params)
self.postnet = nn.Sequential(
nn.Conv1d(self.hidden_channels, self.hidden_channels, 1), nn.BatchNorm1d(self.hidden_channels)
)
elif encoder_type.lower() == "time_depth_separable":
if use_prenet:
self.prenet = ResidualConv1dLayerNormBlock(
hidden_channels, hidden_channels, hidden_channels, kernel_size=5, num_layers=3, dropout_p=0.5
)
self.encoder = TimeDepthSeparableConvBlock(
hidden_channels, hidden_channels, hidden_channels, **encoder_params
)
else:
raise ValueError(" [!] Unknown encoder type.")
@ -143,8 +132,8 @@ class Encoder(nn.Module):
self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
# duration predictor
self.duration_predictor = DurationPredictor(
hidden_channels + c_in_channels, hidden_channels_dp, 3,
dropout_p_dp)
hidden_channels + c_in_channels, hidden_channels_dp, 3, dropout_p_dp
)

def forward(self, x, x_lengths, g=None):
"""

@ -159,15 +148,14 @@ class Encoder(nn.Module):
# [B, D, T]
x = torch.transpose(x, 1, -1)
# compute input sequence mask
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
# prenet
if hasattr(self, 'prenet') and self.use_prenet:
if hasattr(self, "prenet") and self.use_prenet:
x = self.prenet(x, x_mask)
# encoder
x = self.encoder(x, x_mask)
# postnet
if hasattr(self, 'postnet'):
if hasattr(self, "postnet"):
x = self.postnet(x) * x_mask
# set duration predictor input
if g is not None:
@ -7,8 +7,7 @@ from ..generic.normalization import LayerNorm


class ResidualConv1dLayerNormBlock(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size,
num_layers, dropout_p):
def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, num_layers, dropout_p):
"""Conv1d with Layer Normalization and residual connection as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf


@ -38,10 +37,10 @@ class ResidualConv1dLayerNormBlock(nn.Module):

for idx in range(num_layers):
self.conv_layers.append(
nn.Conv1d(in_channels if idx == 0 else hidden_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2))
nn.Conv1d(
in_channels if idx == 0 else hidden_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.norm_layers.append(LayerNorm(hidden_channels))

self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
@ -72,6 +71,7 @@ class InvConvNear(nn.Module):
perform 1x1 convolution separately. Cast 1x1 conv operation
to 2d by reshaping the input for efficiency.
"""

def __init__(self, channels, num_splits=4, no_jacobian=False, **kwargs):  # pylint: disable=unused-argument
super().__init__()
assert num_splits % 2 == 0

@ -80,8 +80,7 @@ class InvConvNear(nn.Module):
self.no_jacobian = no_jacobian
self.weight_inv = None

w_init = torch.qr(
torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
w_init = torch.qr(torch.FloatTensor(self.num_splits, self.num_splits).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
@ -97,28 +96,25 @@ class InvConvNear(nn.Module):
assert c % self.num_splits == 0
if x_mask is None:
x_mask = 1
x_len = torch.ones((b, ), dtype=x.dtype, device=x.device) * t
x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
else:
x_len = torch.sum(x_mask, [1, 2])

x = x.view(b, 2, c // self.num_splits, self.num_splits // 2, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits,
c // self.num_splits, t)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.num_splits, c // self.num_splits, t)

if reverse:
if self.weight_inv is not None:
weight = self.weight_inv
else:
weight = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
weight = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
logdet = None
else:
weight = self.weight
if self.no_jacobian:
logdet = 0
else:
logdet = torch.logdet(
self.weight) * (c / self.num_splits) * x_len  # [b]
logdet = torch.logdet(self.weight) * (c / self.num_splits) * x_len  # [b]

weight = weight.view(self.num_splits, self.num_splits, 1, 1)
z = F.conv2d(x, weight)

@ -128,40 +124,42 @@ class InvConvNear(nn.Module):
return z, logdet

def store_inverse(self):
weight_inv = torch.inverse(
self.weight.float()).to(dtype=self.weight.dtype)
weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
self.weight_inv = nn.Parameter(weight_inv, requires_grad=False)

class CouplingBlock(nn.Module):
"""Glow Affine Coupling block as in GlowTTS paper.
https://arxiv.org/pdf/1811.00002.pdf
https://arxiv.org/pdf/1811.00002.pdf

x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^
x --> x0 -> conv1d -> wavenet -> conv1d --> t, s -> concat(s*x1 + t, x0) -> o
'-> x1 - - - - - - - - - - - - - - - - - - - - - - - - - ^

Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size.
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
num_layers (int): number of WaveNet layers.
c_in_channels (int): number of conditioning input channels.
dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.
Args:
in_channels (int): number of input tensor channels.
hidden_channels (int): number of hidden channels.
kernel_size (int): WaveNet filter kernel size.
dilation_rate (int): rate to increase dilation by each layer in a decoder block.
num_layers (int): number of WaveNet layers.
c_in_channels (int): number of conditioning input channels.
dropout_p (int): wavenet dropout rate.
sigmoid_scale (bool): enable/disable sigmoid scaling for output scale.

Note:
It does not use conditional inputs differently from WaveGlow.
Note:
It does not use conditional inputs differently from WaveGlow.
"""

def __init__(self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
sigmoid_scale=False):
def __init__(
self,
in_channels,
hidden_channels,
kernel_size,
dilation_rate,
num_layers,
c_in_channels=0,
dropout_p=0,
sigmoid_scale=False,
):
super().__init__()
self.in_channels = in_channels
self.hidden_channels = hidden_channels
@ -183,8 +181,7 @@ class CouplingBlock(nn.Module):
end.bias.data.zero_()
self.end = end
# coupling layers
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate,
num_layers, c_in_channels, dropout_p)
self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)

def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
"""

@ -195,15 +192,15 @@ class CouplingBlock(nn.Module):
"""
if x_mask is None:
x_mask = 1
x_0, x_1 = x[:, :self.in_channels // 2], x[:, self.in_channels // 2:]
x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :]

x = self.start(x_0) * x_mask
x = self.wn(x, x_mask, g)
out = self.end(x)

z_0 = x_0
t = out[:, :self.in_channels // 2, :]
s = out[:, self.in_channels // 2:, :]
t = out[:, : self.in_channels // 2, :]
s = out[:, self.in_channels // 2 :, :]
if self.sigmoid_scale:
s = torch.log(1e-6 + torch.sigmoid(s + 2))
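Only x0 passes through unchanged while x1 is transformed, which is what keeps the block invertible. A sketch of both directions under the affine form sketched in the docstring diagram (z1 = exp(s) * x1 + t):

import torch

x1 = torch.randn(2, 40, 30)
t, s = torch.randn_like(x1), torch.randn_like(x1) * 0.1
z1 = torch.exp(s) * x1 + t         # forward half of the coupling
x1_rec = (z1 - t) * torch.exp(-s)  # reverse recovers x1 exactly
assert torch.allclose(x1_rec, x1, atol=1e-5)
logdet = torch.sum(s, dim=(1, 2))  # per-sample Jacobian term (before masking)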
@ -6,6 +6,7 @@ from TTS.tts.utils.generic_utils import sequence_mask
try:
# TODO: fix pypi cython installation problem.
from TTS.tts.layers.glow_tts.monotonic_align.core import maximum_path_c

CYTHON = True
except ModuleNotFoundError:
CYTHON = False

@ -30,8 +31,7 @@ def generate_path(duration, mask):
cum_duration_flat = cum_duration.view(b * t_x)
path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
path = path.view(b, t_x, t_y)
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]
]))[:, :-1]
path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
path = path * mask
return path

@ -43,7 +43,7 @@ def maximum_path(value, mask):


def maximum_path_cython(value, mask):
""" Cython optimised version.
"""Cython optimised version.
value: [b, t_x, t_y]
mask: [b, t_x, t_y]
"""
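generate_path turns integer durations into a hard monotonic alignment: each token owns a contiguous run of output frames ending at its cumulative duration. A tiny numeric sketch of the same trick, with sequence_mask replaced by a plain comparison:

import torch
import torch.nn.functional as F

duration = torch.tensor([[2, 1, 3]])  # 3 tokens covering 6 output frames
cum = torch.cumsum(duration, dim=1)   # [[2, 3, 6]]
frames = torch.arange(int(cum[0, -1]))
mask = (frames[None, None, :] < cum[:, :, None]).float()
path = mask - F.pad(mask, (0, 0, 1, 0))[:, :-1]
# path[0]: token 0 -> frames 0-1, token 1 -> frame 2, token 2 -> frames 3-5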
@ -48,16 +48,19 @@ class RelativePositionMultiHeadAttention(nn.Module):
proximal_init (bool, optional): enable/disable proximal init as in the paper.
Init key and query layer weights the same. Defaults to False.
"""
def __init__(self,
channels,
out_channels,
num_heads,
rel_attn_window_size=None,
heads_share=True,
dropout_p=0.,
input_length=None,
proximal_bias=False,
proximal_init=False):

def __init__(
self,
channels,
out_channels,
num_heads,
rel_attn_window_size=None,
heads_share=True,
dropout_p=0.0,
input_length=None,
proximal_bias=False,
proximal_init=False,
):

super().__init__()
assert channels % num_heads == 0, " [!] channels should be divisible by num_heads."
@ -82,15 +85,15 @@ class RelativePositionMultiHeadAttention(nn.Module):
# relative positional encoding layers
if rel_attn_window_size is not None:
n_heads_rel = 1 if heads_share else num_heads
rel_stddev = self.k_channels**-0.5
rel_stddev = self.k_channels ** -0.5
emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
self.k_channels) * rel_stddev)
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1,
self.k_channels) * rel_stddev)
self.register_parameter('emb_rel_k', emb_rel_k)
self.register_parameter('emb_rel_v', emb_rel_v)
torch.randn(n_heads_rel, rel_attn_window_size * 2 + 1, self.k_channels) * rel_stddev
)
self.register_parameter("emb_rel_k", emb_rel_k)
self.register_parameter("emb_rel_v", emb_rel_v)

# init layers
nn.init.xavier_uniform_(self.conv_q.weight)
@ -112,38 +115,30 @@ class RelativePositionMultiHeadAttention(nn.Module):
def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, t_t = (*key.size(), query.size(2))
query = query.view(b, self.num_heads, self.k_channels,
t_t).transpose(2, 3)
query = query.view(b, self.num_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.num_heads, self.k_channels,
t_s).transpose(2, 3)
value = value.view(b, self.num_heads, self.k_channels, t_s).transpose(2, 3)
# compute raw attention scores
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(
self.k_channels)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
# relative positional encoding for scores
if self.rel_attn_window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
# get relative key embeddings
key_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(
rel_logits)
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
rel_logits = self._relative_position_to_absolute_position(rel_logits)
scores_local = rel_logits / math.sqrt(self.k_channels)
scores = scores + scores_local
# proximal bias
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attn_proximity_bias(t_s).to(
device=scores.device, dtype=scores.dtype)
scores = scores + self._attn_proximity_bias(t_s).to(device=scores.device, dtype=scores.dtype)
# attention score masking
if mask is not None:
# add small value to prevent oor error.
scores = scores.masked_fill(mask == 0, -1e4)
if self.input_length is not None:
block_mask = torch.ones_like(scores).triu(
-1 * self.input_length).tril(self.input_length)
block_mask = torch.ones_like(scores).triu(-1 * self.input_length).tril(self.input_length)
scores = scores * block_mask + -1e4 * (1 - block_mask)
# attention score normalization
p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
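The content scores and the relative-position logits share the same 1/sqrt(d_k) scaling before the softmax. A compact sketch of the content half (the relative term adds a per-offset logit on top of these scores):

import math
import torch

b, n_h, t, d_k = 2, 4, 20, 16
query = torch.randn(b, n_h, t, d_k)
key = torch.randn(b, n_h, t, d_k)
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
p_attn = torch.softmax(scores, dim=-1)  # each row sums to 1 over source steps
assert torch.allclose(p_attn.sum(-1), torch.ones(b, n_h, t))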
@ -153,14 +148,10 @@ class RelativePositionMultiHeadAttention(nn.Module):
output = torch.matmul(p_attn, value)
# relative positional encoding for values
if self.rel_attn_window_size is not None:
relative_weights = self._absolute_position_to_relative_position(
p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(
b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t)  # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn

@staticmethod
@ -195,20 +186,16 @@ class RelativePositionMultiHeadAttention(nn.Module):
return logits

def _get_relative_embeddings(self, relative_embeddings, length):
"""Convert embedding vectors to a tensor of embeddings
"""
"""Convert embedding vectors to a tensor of embeddings"""
# Pad first before slice to avoid using cond ops.
pad_length = max(length - (self.rel_attn_window_size + 1), 0)
slice_start_position = max((self.rel_attn_window_size + 1) - length, 0)
slice_end_position = slice_start_position + 2 * length - 1
if pad_length > 0:
padded_relative_embeddings = F.pad(
relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
padded_relative_embeddings = F.pad(relative_embeddings, [0, 0, pad_length, pad_length, 0, 0])
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[:,
slice_start_position:
slice_end_position]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings

@staticmethod
@ -226,8 +213,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(x_flat, [0, length - 1, 0, 0, 0, 0])
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1,
2 * length - 1])[:, :, :length, length - 1:]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final

@staticmethod

@ -239,7 +225,7 @@ class RelativePositionMultiHeadAttention(nn.Module):
batch, heads, length, _ = x.size()
# pad along column
x = F.pad(x, [0, length - 1, 0, 0, 0, 0, 0, 0])
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, [length, 0, 0, 0, 0, 0])
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
@ -267,19 +253,15 @@ class RelativePositionMultiHeadAttention(nn.Module):
class FeedForwardNetwork(nn.Module):
"""Feed Forward Inner layers for Transformer.

Args:
in_channels (int): input tensor channels.
out_channels (int): output tensor channels.
hidden_channels (int): inner layers hidden channels.
kernel_size (int): conv1d filter kernel size.
dropout_p (float, optional): dropout rate. Defaults to 0.
Args:
in_channels (int): input tensor channels.
out_channels (int): output tensor channels.
hidden_channels (int): inner layers hidden channels.
kernel_size (int): conv1d filter kernel size.
dropout_p (float, optional): dropout rate. Defaults to 0.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels,
kernel_size,
dropout_p=0.):

def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dropout_p=0.0):

super().__init__()
self.in_channels = in_channels
@ -288,14 +270,8 @@ class FeedForwardNetwork(nn.Module):
self.kernel_size = kernel_size
self.dropout_p = dropout_p

self.conv_1 = nn.Conv1d(in_channels,
hidden_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels,
out_channels,
kernel_size,
padding=kernel_size // 2)
self.conv_1 = nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2)
self.conv_2 = nn.Conv1d(hidden_channels, out_channels, kernel_size, padding=kernel_size // 2)
self.dropout = nn.Dropout(dropout_p)

def forward(self, x, x_mask):
@ -308,34 +284,37 @@ class FeedForwardNetwork(nn.Module):

class RelativePositionTransformer(nn.Module):
"""Transformer with Relative Positional Encoding.
https://arxiv.org/abs/1803.02155
https://arxiv.org/abs/1803.02155

Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0.
rel_attn_window_size (int, optional): relation attention window size.
If 4, for each time step next and previous 4 time steps are attended.
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
Args:
in_channels (int): number of channels of the input tensor.
out_channels (int): number of channels of the output tensor.
hidden_channels (int): model hidden channels.
hidden_channels_ffn (int): hidden channels of FeedForwardNetwork.
num_heads (int): number of attention heads.
num_layers (int): number of transformer layers.
kernel_size (int, optional): kernel size of feed-forward inner layers. Defaults to 1.
dropout_p (float, optional): dropout rate for self-attention and feed-forward inner layers_per_stack. Defaults to 0.
rel_attn_window_size (int, optional): relation attention window size.
If 4, for each time step next and previous 4 time steps are attended.
If default, relative encoding is disabled and it is a regular transformer.
Defaults to None.
input_length (int, optional): input length to limit position encoding. Defaults to None.
"""
def __init__(self,
in_channels,
out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
num_layers,
kernel_size=1,
dropout_p=0.,
rel_attn_window_size=None,
input_length=None):

def __init__(
self,
in_channels,
out_channels,
hidden_channels,
hidden_channels_ffn,
num_heads,
num_layers,
kernel_size=1,
dropout_p=0.0,
rel_attn_window_size=None,
input_length=None,
):
super().__init__()
self.hidden_channels = hidden_channels
self.hidden_channels_ffn = hidden_channels_ffn
@ -359,7 +338,9 @@ class RelativePositionTransformer(nn.Module):
num_heads,
rel_attn_window_size=rel_attn_window_size,
dropout_p=dropout_p,
input_length=input_length))
input_length=input_length,
)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))

if hidden_channels != out_channels and (idx + 1) == self.num_layers:
@ -368,15 +349,14 @@ class RelativePositionTransformer(nn.Module):
self.ffn_layers.append(
FeedForwardNetwork(
hidden_channels,
hidden_channels if
(idx + 1) != self.num_layers else out_channels,
hidden_channels if (idx + 1) != self.num_layers else out_channels,
hidden_channels_ffn,
kernel_size,
dropout_p=dropout_p))
dropout_p=dropout_p,
)
)

self.norm_layers_2.append(
LayerNorm(hidden_channels if (
idx + 1) != self.num_layers else out_channels))
self.norm_layers_2.append(LayerNorm(hidden_channels if (idx + 1) != self.num_layers else out_channels))

def forward(self, x, x_mask):
"""

@ -394,7 +374,7 @@ class RelativePositionTransformer(nn.Module):
y = self.ffn_layers[i](x, x_mask)
y = self.dropout(y)

if (i + 1) == self.num_layers and hasattr(self, 'proj'):
if (i + 1) == self.num_layers and hasattr(self, "proj"):
x = self.proj(x)

x = self.norm_layers_2[i](x + y)
@ -34,26 +34,23 @@ class L1LossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.l1_loss(x * mask,
target * mask,
reduction='none')
loss = functional.l1_loss(x * mask, target * mask, reduction="none")
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.l1_loss(x * mask, target * mask, reduction='sum')
loss = functional.l1_loss(x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
return loss


class MSELossMasked(nn.Module):
def __init__(self, seq_len_norm):
super(MSELossMasked, self).__init__()
super().__init__()
self.seq_len_norm = seq_len_norm

def forward(self, x, target, length):

@ -76,27 +73,23 @@ class MSELossMasked(nn.Module):
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).unsqueeze(2).float()
if self.seq_len_norm:
norm_w = mask / mask.sum(dim=1, keepdim=True)
out_weights = norm_w.div(target.shape[0] * target.shape[2])
mask = mask.expand_as(x)
loss = functional.mse_loss(x * mask,
target * mask,
reduction='none')
loss = functional.mse_loss(x * mask, target * mask, reduction="none")
loss = loss.mul(out_weights.to(loss.device)).sum()
else:
mask = mask.expand_as(x)
loss = functional.mse_loss(x * mask,
target * mask,
reduction='sum')
loss = functional.mse_loss(x * mask, target * mask, reduction="sum")
loss = loss / mask.sum()
return loss


class SSIMLoss(torch.nn.Module):
"""SSIM loss as explained here https://en.wikipedia.org/wiki/Structural_similarity"""

def __init__(self):
super().__init__()
self.loss_func = ssim

@ -115,9 +108,7 @@ class SSIMLoss(torch.nn.Module):
loss: An average loss value in range [0, 1] masked by the length.
"""
if length is not None:
m = sequence_mask(sequence_length=length,
max_len=y.size(1)).unsqueeze(2).float().to(
y_hat.device)
m = sequence_mask(sequence_length=length, max_len=y.size(1)).unsqueeze(2).float().to(y_hat.device)
y_hat, y = y_hat * m, y * m
return 1 - self.loss_func(y_hat.unsqueeze(1), y.unsqueeze(1))
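To illustrate the masking pattern shared by these losses, here is a minimal standalone sketch (written for this note, not part of the commit; `sequence_mask` below is a simplified stand-in for the repo helper). Padded frames are zeroed on both sides and the sum is normalized by the number of valid elements:

import torch
import torch.nn.functional as F

def sequence_mask(lengths, max_len):
    # boolean mask, True where t < length
    return torch.arange(max_len)[None, :] < lengths[:, None]

x = torch.randn(2, 5, 3)        # predictions (B, T, C)
target = torch.randn(2, 5, 3)   # zero-padded targets
length = torch.tensor([5, 3])   # valid frames per batch item

mask = sequence_mask(length, x.size(1)).unsqueeze(2).float()   # (B, T, 1)
mask = mask.expand_as(x)
loss = F.l1_loss(x * mask, target * mask, reduction="sum") / mask.sum()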
@ -139,7 +130,7 @@ class AttentionEntropyLoss(nn.Module):

class BCELossMasked(nn.Module):
def __init__(self, pos_weight):
super(BCELossMasked, self).__init__()
super().__init__()
self.pos_weight = pos_weight

def forward(self, x, target, length):

@ -163,25 +154,20 @@ class BCELossMasked(nn.Module):
# mask: (batch, max_len, 1)
target.requires_grad = False
if length is not None:
mask = sequence_mask(sequence_length=length,
max_len=target.size(1)).float()
mask = sequence_mask(sequence_length=length, max_len=target.size(1)).float()
x = x * mask
target = target * mask
num_items = mask.sum()
else:
num_items = torch.numel(x)
loss = functional.binary_cross_entropy_with_logits(
x,
target,
pos_weight=self.pos_weight,
reduction='sum')
loss = functional.binary_cross_entropy_with_logits(x, target, pos_weight=self.pos_weight, reduction="sum")
loss = loss / num_items
return loss


class DifferentailSpectralLoss(nn.Module):
"""Differential Spectral Loss
https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf"""
https://arxiv.org/ftp/arxiv/papers/1909/1909.10302.pdf"""

def __init__(self, loss_func):
super().__init__()

@ -200,12 +186,12 @@ class DifferentailSpectralLoss(nn.Module):
target_diff = target[:, 1:] - target[:, :-1]
if length is None:
return self.loss_func(x_diff, target_diff)
return self.loss_func(x_diff, target_diff, length-1)
return self.loss_func(x_diff, target_diff, length - 1)

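A short illustration of the differential loss just shown (a sketch for this note, not commit code): the loss is taken on first-order differences along time, so it penalizes how the spectrum changes frame-to-frame rather than its absolute values:

import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 80)        # predicted spectrogram (B, T, C)
target = torch.randn(2, 10, 80)

x_diff = x[:, 1:] - x[:, :-1]              # (B, T-1, C) frame deltas
target_diff = target[:, 1:] - target[:, :-1]
loss = F.l1_loss(x_diff, target_diff)      # any frame-level loss can be wrapped here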
class GuidedAttentionLoss(torch.nn.Module):
def __init__(self, sigma=0.4):
super(GuidedAttentionLoss, self).__init__()
super().__init__()
self.sigma = sigma

def _make_ga_masks(self, ilens, olens):

@ -214,8 +200,7 @@ class GuidedAttentionLoss(torch.nn.Module):
max_olen = max(olens)
ga_masks = torch.zeros((B, max_olen, max_ilen))
for idx, (ilen, olen) in enumerate(zip(ilens, olens)):
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(
ilen, olen, self.sigma)
ga_masks[idx, :olen, :ilen] = self._make_ga_mask(ilen, olen, self.sigma)
return ga_masks

def forward(self, att_ws, ilens, olens):

@ -229,8 +214,7 @@ class GuidedAttentionLoss(torch.nn.Module):
def _make_ga_mask(ilen, olen, sigma):
grid_x, grid_y = torch.meshgrid(torch.arange(olen).to(olen), torch.arange(ilen).to(ilen))
grid_x, grid_y = grid_x.float(), grid_y.float()
return 1.0 - torch.exp(-(grid_y / ilen - grid_x / olen)**2 /
(2 * (sigma**2)))
return 1.0 - torch.exp(-((grid_y / ilen - grid_x / olen) ** 2) / (2 * (sigma ** 2)))

@staticmethod
def _make_masks(ilens, olens):
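To make the mask formula above concrete, a standalone sketch (assumed toy sizes; not part of the commit). The penalty is near zero on the diagonal n/N ≈ t/T and grows away from it, which nudges attention toward a monotonic alignment:

import torch

def make_ga_mask(ilen, olen, sigma=0.4):
    t = torch.arange(olen).float().unsqueeze(1)  # output steps, (olen, 1)
    n = torch.arange(ilen).float().unsqueeze(0)  # input steps, (1, ilen)
    return 1.0 - torch.exp(-((n / ilen - t / olen) ** 2) / (2 * sigma ** 2))

mask = make_ga_mask(ilen=6, olen=8)  # (8, 6); multiplied with attention weights in the loss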
@ -249,18 +233,19 @@ class Huber(nn.Module):
length: B
"""
mask = sequence_mask(sequence_length=length, max_len=y.size(1)).float()
return torch.nn.functional.smooth_l1_loss(
x * mask, y * mask, reduction='sum') / mask.sum()
return torch.nn.functional.smooth_l1_loss(x * mask, y * mask, reduction="sum") / mask.sum()


########################
# MODEL LOSS LAYERS
########################


class TacotronLoss(torch.nn.Module):
"""Collection of Tacotron set-up based on provided config."""

def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4):
super(TacotronLoss, self).__init__()
super().__init__()
self.stopnet_pos_weight = stopnet_pos_weight
self.ga_alpha = c.ga_alpha
self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha

@ -273,12 +258,9 @@ class TacotronLoss(torch.nn.Module):

# postnet and decoder loss
if c.loss_masking:
self.criterion = L1LossMasked(c.seq_len_norm) if c.model in [
"Tacotron"
] else MSELossMasked(c.seq_len_norm)
self.criterion = L1LossMasked(c.seq_len_norm) if c.model in ["Tacotron"] else MSELossMasked(c.seq_len_norm)
else:
self.criterion = nn.L1Loss() if c.model in ["Tacotron"
] else nn.MSELoss()
self.criterion = nn.L1Loss() if c.model in ["Tacotron"] else nn.MSELoss()
# guided attention loss
if c.ga_alpha > 0:
self.criterion_ga = GuidedAttentionLoss(sigma=ga_sigma)

@ -290,13 +272,23 @@ class TacotronLoss(torch.nn.Module):
self.criterion_ssim = SSIMLoss()
# stopnet loss
# pylint: disable=not-callable
self.criterion_st = BCELossMasked(
pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None

def forward(self, postnet_output, decoder_output, mel_input, linear_input,
stopnet_output, stopnet_target, output_lens, decoder_b_output,
alignments, alignment_lens, alignments_backwards, input_lens):
self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None

def forward(
self,
postnet_output,
decoder_output,
mel_input,
linear_input,
stopnet_output,
stopnet_target,
output_lens,
decoder_b_output,
alignments,
alignment_lens,
alignments_backwards,
input_lens,
):

# decoder outputs linear or mel spectrograms for Tacotron and Tacotron2
# the target should be set accordingly
@ -309,85 +301,80 @@ class TacotronLoss(torch.nn.Module):
# decoder and postnet losses
if self.config.loss_masking:
if self.decoder_alpha > 0:
decoder_loss = self.criterion(decoder_output, mel_input,
output_lens)
decoder_loss = self.criterion(decoder_output, mel_input, output_lens)
if self.postnet_alpha > 0:
postnet_loss = self.criterion(postnet_output, postnet_target,
output_lens)
postnet_loss = self.criterion(postnet_output, postnet_target, output_lens)
else:
if self.decoder_alpha > 0:
decoder_loss = self.criterion(decoder_output, mel_input)
if self.postnet_alpha > 0:
postnet_loss = self.criterion(postnet_output, postnet_target)
loss = self.decoder_alpha * decoder_loss + self.postnet_alpha * postnet_loss
return_dict['decoder_loss'] = decoder_loss
return_dict['postnet_loss'] = postnet_loss
return_dict["decoder_loss"] = decoder_loss
return_dict["postnet_loss"] = postnet_loss

# stopnet loss
stop_loss = self.criterion_st(
stopnet_output, stopnet_target,
output_lens) if self.config.stopnet else torch.zeros(1)
stop_loss = (
self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1)
)
if not self.config.separate_stopnet and self.config.stopnet:
loss += stop_loss
return_dict['stopnet_loss'] = stop_loss
return_dict["stopnet_loss"] = stop_loss

# backward decoder loss (if enabled)
if self.config.bidirectional_decoder:
if self.config.loss_masking:
decoder_b_loss = self.criterion(
torch.flip(decoder_b_output, dims=(1, )), mel_input,
output_lens)
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input, output_lens)
else:
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1, )), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1, )), decoder_output)
decoder_b_loss = self.criterion(torch.flip(decoder_b_output, dims=(1,)), mel_input)
decoder_c_loss = torch.nn.functional.l1_loss(torch.flip(decoder_b_output, dims=(1,)), decoder_output)
loss += self.decoder_alpha * (decoder_b_loss + decoder_c_loss)
return_dict['decoder_b_loss'] = decoder_b_loss
return_dict['decoder_c_loss'] = decoder_c_loss
return_dict["decoder_b_loss"] = decoder_b_loss
return_dict["decoder_c_loss"] = decoder_c_loss

# double decoder consistency loss (if enabled)
if self.config.double_decoder_consistency:
if self.config.loss_masking:
decoder_b_loss = self.criterion(decoder_b_output, mel_input,
output_lens)
decoder_b_loss = self.criterion(decoder_b_output, mel_input, output_lens)
else:
decoder_b_loss = self.criterion(decoder_b_output, mel_input)
# decoder_c_loss = torch.nn.functional.l1_loss(decoder_b_output, decoder_output)
attention_c_loss = torch.nn.functional.l1_loss(alignments, alignments_backwards)
loss += self.decoder_alpha * (decoder_b_loss + attention_c_loss)
return_dict['decoder_coarse_loss'] = decoder_b_loss
return_dict['decoder_ddc_loss'] = attention_c_loss
return_dict["decoder_coarse_loss"] = decoder_b_loss
return_dict["decoder_ddc_loss"] = attention_c_loss

# guided attention loss (if enabled)
if self.config.ga_alpha > 0:
ga_loss = self.criterion_ga(alignments, input_lens, alignment_lens)
loss += ga_loss * self.ga_alpha
return_dict['ga_loss'] = ga_loss
return_dict["ga_loss"] = ga_loss

# decoder differential spectral loss
if self.config.decoder_diff_spec_alpha > 0:
decoder_diff_spec_loss = self.criterion_diff_spec(decoder_output, mel_input, output_lens)
loss += decoder_diff_spec_loss * self.decoder_diff_spec_alpha
return_dict['decoder_diff_spec_loss'] = decoder_diff_spec_loss
return_dict["decoder_diff_spec_loss"] = decoder_diff_spec_loss

# postnet differential spectral loss
if self.config.postnet_diff_spec_alpha > 0:
postnet_diff_spec_loss = self.criterion_diff_spec(postnet_output, postnet_target, output_lens)
loss += postnet_diff_spec_loss * self.postnet_diff_spec_alpha
return_dict['postnet_diff_spec_loss'] = postnet_diff_spec_loss
return_dict["postnet_diff_spec_loss"] = postnet_diff_spec_loss

# decoder ssim loss
if self.config.decoder_ssim_alpha > 0:
decoder_ssim_loss = self.criterion_ssim(decoder_output, mel_input, output_lens)
loss += decoder_ssim_loss * self.postnet_ssim_alpha
return_dict['decoder_ssim_loss'] = decoder_ssim_loss
return_dict["decoder_ssim_loss"] = decoder_ssim_loss

# postnet ssim loss
if self.config.postnet_ssim_alpha > 0:
postnet_ssim_loss = self.criterion_ssim(postnet_output, postnet_target, output_lens)
loss += postnet_ssim_loss * self.postnet_ssim_alpha
return_dict['postnet_ssim_loss'] = postnet_ssim_loss
return_dict["postnet_ssim_loss"] = postnet_ssim_loss

return_dict['loss'] = loss
return_dict["loss"] = loss

# check if any loss is NaN
for key, loss in return_dict.items():
@ -401,22 +388,18 @@ class GlowTTSLoss(torch.nn.Module):
super().__init__()
self.constant_factor = 0.5 * math.log(2 * math.pi)

def forward(self, z, means, scales, log_det, y_lengths, o_dur_log,
o_attn_dur, x_lengths):
def forward(self, z, means, scales, log_det, y_lengths, o_dur_log, o_attn_dur, x_lengths):
return_dict = {}
# flow loss - neg log likelihood
pz = torch.sum(scales) + 0.5 * torch.sum(
torch.exp(-2 * scales) * (z - means)**2)
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (
torch.sum(y_lengths) * z.shape[1])
pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
log_mle = self.constant_factor + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])
# duration loss - MSE
# loss_dur = torch.sum((o_dur_log - o_attn_dur)**2) / torch.sum(x_lengths)
# duration loss - huber loss
loss_dur = torch.nn.functional.smooth_l1_loss(
o_dur_log, o_attn_dur, reduction='sum') / torch.sum(x_lengths)
return_dict['loss'] = log_mle + loss_dur
return_dict['log_mle'] = log_mle
return_dict['loss_dur'] = loss_dur
loss_dur = torch.nn.functional.smooth_l1_loss(o_dur_log, o_attn_dur, reduction="sum") / torch.sum(x_lengths)
return_dict["loss"] = log_mle + loss_dur
return_dict["log_mle"] = log_mle
return_dict["loss_dur"] = loss_dur

# check if any loss is NaN
for key, loss in return_dict.items():
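The flow loss above is the per-dimension negative log-likelihood of z under a diagonal Gaussian, minus the flow's log-determinant, averaged over valid frames. A standalone sketch of the same quantity with toy inputs (illustrative only; `scales` holds log-sigma, as in the code):

import math
import torch

z = torch.randn(2, 80, 50)            # latents (B, C, T)
means = torch.zeros_like(z)
scales = torch.zeros_like(z)          # log-sigma
log_det = torch.zeros(2)
y_lengths = torch.tensor([50, 50])

pz = torch.sum(scales) + 0.5 * torch.sum(torch.exp(-2 * scales) * (z - means) ** 2)
nll = 0.5 * math.log(2 * math.pi) + (pz - torch.sum(log_det)) / (torch.sum(y_lengths) * z.shape[1])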
@ -441,7 +424,7 @@ class SpeedySpeechLoss(nn.Module):
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
huber_loss = self.huber(dur_output, dur_target, input_lens)
loss = self.l1_alpha * l1_loss + self.ssim_alpha * ssim_loss + self.huber_alpha * huber_loss
return {'loss': loss, 'loss_l1': l1_loss, 'loss_ssim': ssim_loss, 'loss_dur': huber_loss}
return {"loss": loss, "loss_l1": l1_loss, "loss_ssim": ssim_loss, "loss_dur": huber_loss}


def mse_loss_custom(x, y):

@ -452,26 +435,27 @@ def mse_loss_custom(x, y):


class MDNLoss(nn.Module):
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf.
"""
"""Mixture of Density Network Loss as described in https://arxiv.org/pdf/2003.01950.pdf."""

def forward(self, logp, text_lengths, mel_lengths):  # pylint: disable=no-self-use
'''
"""
Shapes:
mu: [B, D, T]
log_sigma: [B, D, T]
mel_spec: [B, D, T]
'''
"""
B, T_seq, T_mel = logp.shape
log_alpha = logp.new_ones(B, T_seq, T_mel)*(-1e4)
log_alpha = logp.new_ones(B, T_seq, T_mel) * (-1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T_mel):
prev_step = torch.cat([log_alpha[:, :, t-1:t], functional.pad(log_alpha[:, :, t-1:t],
(0, 0, 1, -1), value=-1e4)], dim=-1)
prev_step = torch.cat(
[log_alpha[:, :, t - 1 : t], functional.pad(log_alpha[:, :, t - 1 : t], (0, 0, 1, -1), value=-1e4)],
dim=-1,
)
log_alpha[:, :, t] = torch.logsumexp(prev_step + 1e-4, dim=-1) + logp[:, :, t]
alpha_last = log_alpha[torch.arange(B), text_lengths-1, mel_lengths-1]
alpha_last = log_alpha[torch.arange(B), text_lengths - 1, mel_lengths - 1]
mdn_loss = -alpha_last.mean() / T_seq
return mdn_loss#, log_prob_matrix
return mdn_loss  # , log_prob_matrix
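The loop above is a forward algorithm over a monotonic alignment lattice: at mel step t, each text position can be reached either from itself or from its predecessor at t-1, combined in log space. A toy sketch of the same recursion (written for this note; shapes are assumed):

import torch
import torch.nn.functional as F

B, T_seq, T_mel = 1, 3, 5
logp = torch.randn(B, T_seq, T_mel)   # per-(token, frame) log-likelihoods

log_alpha = logp.new_full((B, T_seq, T_mel), -1e4)
log_alpha[:, 0, 0] = logp[:, 0, 0]
for t in range(1, T_mel):
    stay = log_alpha[:, :, t - 1]                       # keep attending the same token
    move = F.pad(stay[:, :-1], (1, 0), value=-1e4)      # advance from the previous token
    log_alpha[:, :, t] = torch.logsumexp(torch.stack([stay, move], dim=-1), dim=-1) + logp[:, :, t]

log_prob = log_alpha[:, -1, -1]   # log-probability summed over all monotonic alignments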
class AlignTTSLoss(nn.Module):

@ -487,6 +471,7 @@ class AlignTTSLoss(nn.Module):
Args:
c (dict): TTS model configuration.
"""

def __init__(self, c):
super().__init__()
self.mdn_loss = MDNLoss()

@ -499,10 +484,10 @@ class AlignTTSLoss(nn.Module):
self.spec_loss_alpha = c.spec_loss_alpha
self.mdn_alpha = c.mdn_alpha

def forward(self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target,
input_lens, step, phase):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(
step)
def forward(
self, logp, decoder_output, decoder_target, decoder_output_lens, dur_output, dur_target, input_lens, step, phase
):
ssim_alpha, dur_loss_alpha, spec_loss_alpha, mdn_alpha = self.set_alphas(step)
spec_loss, ssim_loss, dur_loss, mdn_loss = 0, 0, 0, 0
if phase == 0:
mdn_loss = self.mdn_loss(logp, input_lens, decoder_output_lens)

@ -521,11 +506,11 @@ class AlignTTSLoss(nn.Module):
ssim_loss = self.ssim(decoder_output, decoder_target, decoder_output_lens)
dur_loss = self.dur_loss(dur_output.unsqueeze(2), dur_target.unsqueeze(2), input_lens)
loss = spec_loss_alpha * spec_loss + ssim_alpha * ssim_loss + dur_loss_alpha * dur_loss + mdn_alpha * mdn_loss
return {'loss': loss, 'loss_l1': spec_loss, 'loss_ssim': ssim_loss, 'loss_dur': dur_loss, 'mdn_loss': mdn_loss}
return {"loss": loss, "loss_l1": spec_loss, "loss_ssim": ssim_loss, "loss_dur": dur_loss, "mdn_loss": mdn_loss}

@staticmethod
def _set_alpha(step, alpha_settings):
'''Set the loss alpha wrt number of steps.
"""Set the loss alpha wrt number of steps.
Return the corresponding value if no schedule is set.

Example:

@ -536,7 +521,7 @@ class AlignTTSLoss(nn.Module):
Args:
step (int): number of training steps.
alpha_settings (int or list): constant alpha value or a list defining the schedule as explained above.
'''
"""
return_alpha = None
if isinstance(alpha_settings, list):
for key, alpha in alpha_settings:

@ -547,8 +532,7 @@ class AlignTTSLoss(nn.Module):
return return_alpha

def set_alphas(self, step):
'''Set the alpha values for all the loss functions
'''
"""Set the alpha values for all the loss functions"""
ssim_alpha = self._set_alpha(step, self.ssim_alpha)
dur_loss_alpha = self._set_alpha(step, self.dur_loss_alpha)
spec_loss_alpha = self._set_alpha(step, self.spec_loss_alpha)

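The `_set_alpha` docstring above describes a step-based schedule; the loop body is cut off by the hunk, so here is a sketch of the behavior it implies, assuming the schedule is a list of [step, value] pairs where the last pair whose step has passed wins (this sketch is not commit code):

def set_alpha(step, alpha_settings):
    # constant value, or a [[step, value], ...] schedule
    if isinstance(alpha_settings, list):
        value = None
        for key, alpha in alpha_settings:
            if key < step:
                value = alpha
        return value
    return alpha_settings

set_alpha(20000, [[0, 1.0], [10000, 0.1]])   # -> 0.1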
@ -14,20 +14,18 @@ class LocationLayer(nn.Module):
attention_n_filters (int, optional): number of filters in convolution. Defaults to 32.
attention_kernel_size (int, optional): kernel size of convolution filter. Defaults to 31.
"""
def __init__(self,
attention_dim,
attention_n_filters=32,
attention_kernel_size=31):
super(LocationLayer, self).__init__()

def __init__(self, attention_dim, attention_n_filters=32, attention_kernel_size=31):
super().__init__()
self.location_conv1d = nn.Conv1d(
in_channels=2,
out_channels=attention_n_filters,
kernel_size=attention_kernel_size,
stride=1,
padding=(attention_kernel_size - 1) // 2,
bias=False)
self.location_dense = Linear(
attention_n_filters, attention_dim, bias=False, init_gain='tanh')
bias=False,
)
self.location_dense = Linear(attention_n_filters, attention_dim, bias=False, init_gain="tanh")

def forward(self, attention_cat):
"""

@ -35,8 +33,7 @@ class LocationLayer(nn.Module):
attention_cat: [B, 2, C]
"""
processed_attention = self.location_conv1d(attention_cat)
processed_attention = self.location_dense(
processed_attention.transpose(1, 2))
processed_attention = self.location_dense(processed_attention.transpose(1, 2))
return processed_attention


@ -49,31 +46,31 @@ class GravesAttention(nn.Module):
query_dim (int): number of channels in query tensor.
K (int): number of Gaussian heads to be used for computing attention.
"""

COEF = 0.3989422917366028  # numpy.sqrt(1/(2*numpy.pi))

def __init__(self, query_dim, K):

super(GravesAttention, self).__init__()
super().__init__()
self._mask_value = 1e-8
self.K = K
# self.attention_alignment = 0.05
self.eps = 1e-5
self.J = None
self.N_a = nn.Sequential(
nn.Linear(query_dim, query_dim, bias=True),
nn.ReLU(),
nn.Linear(query_dim, 3*K, bias=True))
nn.Linear(query_dim, query_dim, bias=True), nn.ReLU(), nn.Linear(query_dim, 3 * K, bias=True)
)
self.attention_weights = None
self.mu_prev = None
self.init_layers()

def init_layers(self):
torch.nn.init.constant_(self.N_a[2].bias[(2*self.K):(3*self.K)], 1.)  # bias mean
torch.nn.init.constant_(self.N_a[2].bias[self.K:(2*self.K)], 10)  # bias std
torch.nn.init.constant_(self.N_a[2].bias[(2 * self.K) : (3 * self.K)], 1.0)  # bias mean
torch.nn.init.constant_(self.N_a[2].bias[self.K : (2 * self.K)], 10)  # bias std

def init_states(self, inputs):
if self.J is None or inputs.shape[1]+1 > self.J.shape[-1]:
self.J = torch.arange(0, inputs.shape[1]+2.0).to(inputs.device) + 0.5
if self.J is None or inputs.shape[1] + 1 > self.J.shape[-1]:
self.J = torch.arange(0, inputs.shape[1] + 2.0).to(inputs.device) + 0.5
self.attention_weights = torch.zeros(inputs.shape[0], inputs.shape[1]).to(inputs.device)
self.mu_prev = torch.zeros(inputs.shape[0], self.K).to(inputs.device)


@ -108,7 +105,7 @@ class GravesAttention(nn.Module):
mu_t = self.mu_prev + torch.nn.functional.softplus(k_t)
g_t = torch.softmax(g_t, dim=-1) + self.eps

j = self.J[:inputs.size(1)+1]
j = self.J[: inputs.size(1) + 1]

# attention weights
phi_t = g_t.unsqueeze(-1) * (1 / (1 + torch.sigmoid((mu_t.unsqueeze(-1) - j) / sig_t.unsqueeze(-1))))
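A compact sketch of the window computation above (illustrative only, with made-up shapes): each of the K heads places a sigmoid window around its center mu (which the model only ever moves forward via mu_prev + softplus(k)), and taking differences over adjacent grid positions — the step the repo performs right after the line shown, which is an assumption here — turns the cumulative mass into per-position weights:

import torch

B, K, T = 2, 5, 12
g = torch.softmax(torch.randn(B, K), dim=-1)    # mixture weights per head
mu = torch.rand(B, K) * T                       # window centers
sig = torch.rand(B, K) + 0.5                    # window widths
j = torch.arange(T + 1).float() + 0.5           # input position grid

phi = g.unsqueeze(-1) * (1.0 / (1.0 + torch.sigmoid((mu.unsqueeze(-1) - j) / sig.unsqueeze(-1))))
alpha = phi.sum(1)                              # (B, T+1) cumulative mass over heads
weights = alpha[:, 1:] - alpha[:, :-1]          # per-position attention weights, (B, T)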
@ -164,21 +161,29 @@ class OriginalAttention(nn.Module):
trans_agent (bool): enable/disable transition agent in the forward attention.
forward_attn_mask (int): enable/disable an explicit masking in forward attention. It is especially useful at inference time.
"""

# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent, forward_attn_mask):
super(OriginalAttention, self).__init__()
self.query_layer = Linear(
query_dim, attention_dim, bias=False, init_gain='tanh')
self.inputs_layer = Linear(
embedding_dim, attention_dim, bias=False, init_gain='tanh')
# pylint: disable=attribute-defined-outside-init
def __init__(
self,
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
):
super().__init__()
self.query_layer = Linear(query_dim, attention_dim, bias=False, init_gain="tanh")
self.inputs_layer = Linear(embedding_dim, attention_dim, bias=False, init_gain="tanh")
self.v = Linear(attention_dim, 1, bias=True)
if trans_agent:
self.ta = nn.Linear(
query_dim + embedding_dim, 1, bias=True)
self.ta = nn.Linear(query_dim + embedding_dim, 1, bias=True)
if location_attention:
self.location_layer = LocationLayer(
attention_dim,

@ -202,9 +207,7 @@ class OriginalAttention(nn.Module):
def init_forward_attn(self, inputs):
B = inputs.shape[0]
T = inputs.shape[1]
self.alpha = torch.cat(
[torch.ones([B, 1]),
torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
self.alpha = torch.cat([torch.ones([B, 1]), torch.zeros([B, T])[:, :-1] + 1e-7], dim=1).to(inputs.device)
self.u = (0.5 * torch.ones([B, 1])).to(inputs.device)

def init_location_attention(self, inputs):

@ -230,14 +233,10 @@ class OriginalAttention(nn.Module):
self.attention_weights_cum += alignments

def get_location_attention(self, query, processed_inputs):
attention_cat = torch.cat((self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)),
dim=1)
attention_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
processed_query = self.query_layer(query.unsqueeze(1))
processed_attention_weights = self.location_layer(attention_cat)
energies = self.v(
torch.tanh(processed_query + processed_attention_weights +
processed_inputs))
energies = self.v(torch.tanh(processed_query + processed_attention_weights + processed_inputs))
energies = energies.squeeze(-1)
return energies, processed_query


@ -264,24 +263,17 @@ class OriginalAttention(nn.Module):

def apply_forward_attention(self, alignment):
# forward attention
fwd_shifted_alpha = F.pad(
self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0))
fwd_shifted_alpha = F.pad(self.alpha[:, :-1].clone().to(alignment.device), (1, 0, 0, 0))
# compute transition potentials
alpha = ((1 - self.u) * self.alpha
+ self.u * fwd_shifted_alpha
+ 1e-8) * alignment
alpha = ((1 - self.u) * self.alpha + self.u * fwd_shifted_alpha + 1e-8) * alignment
# force incremental alignment
if not self.training and self.forward_attn_mask:
_, n = fwd_shifted_alpha.max(1)
val, _ = alpha.max(1)
for b in range(alignment.shape[0]):
alpha[b, n[b] + 3:] = 0
alpha[b, :(
n[b] - 1
)] = 0  # ignore all previous states to prevent repetition.
alpha[b,
(n[b] - 2
)] = 0.01 * val[b]  # smoothing factor for the prev step
alpha[b, n[b] + 3 :] = 0
alpha[b, : (n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
# renormalize attention weights
alpha = alpha / alpha.sum(dim=1, keepdim=True)
return alpha

@ -295,11 +287,9 @@ class OriginalAttention(nn.Module):
mask: [B, T_en]
"""
if self.location_attention:
attention, _ = self.get_location_attention(
query, processed_inputs)
attention, _ = self.get_location_attention(query, processed_inputs)
else:
attention, _ = self.get_attention(
query, processed_inputs)
attention, _ = self.get_attention(query, processed_inputs)
# apply masking
if mask is not None:
attention.data.masked_fill_(~mask, self._mask_value)

@ -311,9 +301,7 @@ class OriginalAttention(nn.Module):
if self.norm == "softmax":
alignment = torch.softmax(attention, dim=-1)
elif self.norm == "sigmoid":
alignment = torch.sigmoid(attention) / torch.sigmoid(
attention).sum(
dim=1, keepdim=True)
alignment = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True)
else:
raise ValueError("Unknown value for attention norm type")
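The two normalization options just shown differ in how raw energies become weights; a tiny standalone comparison (not commit code) on the same energy vector:

import torch

energies = torch.randn(1, 7)   # (B, T_en) raw attention energies

softmax_w = torch.softmax(energies, dim=-1)        # sharp distribution, sums to 1
sig = torch.sigmoid(energies)
sigmoid_w = sig / sig.sum(dim=1, keepdim=True)     # smoother distribution, also sums to 1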
@ -367,19 +355,20 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
alpha (float, optional): shape parameter of the beta-binomial alignment prior. Defaults to 0.1 from the paper.
beta (float, optional): shape parameter of the beta-binomial alignment prior. Defaults to 0.9 from the paper.
"""

def __init__(
self,
query_dim,
embedding_dim,  # pylint: disable=unused-argument
attention_dim,
static_filter_dim,
static_kernel_size,
dynamic_filter_dim,
dynamic_kernel_size,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
):
self,
query_dim,
embedding_dim,  # pylint: disable=unused-argument
attention_dim,
static_filter_dim,
static_kernel_size,
dynamic_filter_dim,
dynamic_kernel_size,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
):
super().__init__()
self._mask_value = 1e-8
self.dynamic_filter_dim = dynamic_filter_dim

@ -388,9 +377,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
self.attention_weights = None
# setup key and query layers
self.query_layer = nn.Linear(query_dim, attention_dim)
self.key_layer = nn.Linear(
attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
)
self.key_layer = nn.Linear(attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False)
self.static_filter_conv = nn.Conv1d(
1,
static_filter_dim,

@ -402,8 +389,7 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
self.v = nn.Linear(attention_dim, 1, bias=False)

prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
alpha, beta)
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
self.register_buffer("prior", torch.FloatTensor(prior).flip(0))

# pylint: disable=unused-argument
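For intuition, the alignment prior registered above can be computed standalone (a sketch; scipy's betabinom is the same function the module imports):

import torch
from scipy.stats import betabinom

prior_filter_len, alpha, beta = 11, 0.1, 0.9
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1, alpha, beta)
prior = torch.FloatTensor(prior).flip(0)   # flipped so most mass sits near the current step
# convolving past attention weights with this kernel biases the next step slightly forward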
@ -416,8 +402,8 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
"""
# compute prior filters
prior_filter = F.conv1d(
F.pad(self.attention_weights.unsqueeze(1),
(self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
F.pad(self.attention_weights.unsqueeze(1), (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1)
)
prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
G = self.key_layer(torch.tanh(self.query_layer(query)))
# compute dynamic filters

@ -430,10 +416,12 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
# compute static filters
static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
alignment = self.v(
torch.tanh(
self.static_filter_layer(static_filter) +
self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
alignment = (
self.v(
torch.tanh(self.static_filter_layer(static_filter) + self.dynamic_filter_layer(dynamic_filter))
).squeeze(-1)
+ prior_filter
)
# compute attention weights
attention_weights = F.softmax(alignment, dim=-1)
# apply masking

@ -451,33 +439,52 @@ class MonotonicDynamicConvolutionAttention(nn.Module):
B = inputs.size(0)
T = inputs.size(1)
self.attention_weights = torch.zeros([B, T], device=inputs.device)
self.attention_weights[:, 0] = 1.
self.attention_weights[:, 0] = 1.0


def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent, forward_attn_mask, attn_K):
def init_attn(
attn_type,
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
attn_K,
):
if attn_type == "original":
return OriginalAttention(query_dim, embedding_dim, attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size, windowing,
norm, forward_attn, trans_agent,
forward_attn_mask)
return OriginalAttention(
query_dim,
embedding_dim,
attention_dim,
location_attention,
attention_location_n_filters,
attention_location_kernel_size,
windowing,
norm,
forward_attn,
trans_agent,
forward_attn_mask,
)
if attn_type == "graves":
return GravesAttention(query_dim, attn_K)
if attn_type == "dynamic_convolution":
return MonotonicDynamicConvolutionAttention(query_dim,
embedding_dim,
attention_dim,
static_filter_dim=8,
static_kernel_size=21,
dynamic_filter_dim=8,
dynamic_kernel_size=21,
prior_filter_len=11,
alpha=0.1,
beta=0.9)
return MonotonicDynamicConvolutionAttention(
query_dim,
embedding_dim,
attention_dim,
static_filter_dim=8,
static_kernel_size=21,
dynamic_filter_dim=8,
dynamic_kernel_size=21,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
)

raise RuntimeError(f" [!] Given Attention Type '{attn_type}' does not exist.")

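A hypothetical call to the factory above, assuming `init_attn` is imported from this module (the values are illustrative and mirror the defaults the Tacotron decoder passes later in this commit):

attention = init_attn(
    attn_type="original",       # or "graves" / "dynamic_convolution"
    query_dim=256,
    embedding_dim=512,
    attention_dim=128,
    location_attention=True,
    attention_location_n_filters=32,
    attention_location_kernel_size=31,
    windowing=False,
    norm="softmax",
    forward_attn=False,
    trans_agent=False,
    forward_attn_mask=False,
    attn_K=5,                   # only used by the Graves variant
)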
@ -12,20 +12,14 @@ class Linear(nn.Module):
bias (bool, optional): enable/disable bias in the layer. Defaults to True.
init_gain (str, optional): method to compute the gain in the weight initialization based on the nonlinear activation used afterwards. Defaults to 'linear'.
"""
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(Linear, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)

def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
super().__init__()
self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
self._init_w(init_gain)

def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

def forward(self, x):
return self.linear_layer(x)

@ -42,21 +36,15 @@ class LinearBN(nn.Module):
bias (bool, optional): enable/disable bias in the linear layer. Defaults to True.
init_gain (str, optional): method to set the gain for weight initialization. Defaults to 'linear'.
"""
def __init__(self,
in_features,
out_features,
bias=True,
init_gain='linear'):
super(LinearBN, self).__init__()
self.linear_layer = torch.nn.Linear(
in_features, out_features, bias=bias)

def __init__(self, in_features, out_features, bias=True, init_gain="linear"):
super().__init__()
self.linear_layer = torch.nn.Linear(in_features, out_features, bias=bias)
self.batch_normalization = nn.BatchNorm1d(out_features, momentum=0.1, eps=1e-5)
self._init_w(init_gain)

def _init_w(self, init_gain):
torch.nn.init.xavier_uniform_(
self.linear_layer.weight,
gain=torch.nn.init.calculate_gain(init_gain))
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(init_gain))

def forward(self, x):
"""

@ -96,27 +84,21 @@ class Prenet(nn.Module):
Defaults to [256, 256].
bias (bool, optional): enable/disable bias in prenet linear layers. Defaults to True.
"""

# pylint: disable=dangerous-default-value
def __init__(self,
in_features,
prenet_type="original",
prenet_dropout=True,
out_features=[256, 256],
bias=True):
super(Prenet, self).__init__()
def __init__(self, in_features, prenet_type="original", prenet_dropout=True, out_features=[256, 256], bias=True):
super().__init__()
self.prenet_type = prenet_type
self.prenet_dropout = prenet_dropout
in_features = [in_features] + out_features[:-1]
if prenet_type == "bn":
self.linear_layers = nn.ModuleList([
LinearBN(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
self.linear_layers = nn.ModuleList(
[LinearBN(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
)
elif prenet_type == "original":
self.linear_layers = nn.ModuleList([
Linear(in_size, out_size, bias=bias)
for (in_size, out_size) in zip(in_features, out_features)
])
self.linear_layers = nn.ModuleList(
[Linear(in_size, out_size, bias=bias) for (in_size, out_size) in zip(in_features, out_features)]
)

def forward(self, x):
for linear in self.linear_layers:

@ -11,8 +11,7 @@ class GST(nn.Module):
def __init__(self, num_mel, num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim=None):
super().__init__()
self.encoder = ReferenceEncoder(num_mel, gst_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens,
gst_embedding_dim, speaker_embedding_dim)
self.style_token_layer = StyleTokenLayer(num_heads, num_style_tokens, gst_embedding_dim, speaker_embedding_dim)

def forward(self, inputs, speaker_embedding=None):
enc_out = self.encoder(inputs)

@ -39,24 +38,17 @@ class ReferenceEncoder(nn.Module):
num_layers = len(filters) - 1
convs = [
nn.Conv2d(
in_channels=filters[i],
out_channels=filters[i + 1],
kernel_size=(3, 3),
stride=(2, 2),
padding=(1, 1)) for i in range(num_layers)
in_channels=filters[i], out_channels=filters[i + 1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
)
for i in range(num_layers)
]
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList([
nn.BatchNorm2d(num_features=filter_size)
for filter_size in filters[1:]
])
self.bns = nn.ModuleList([nn.BatchNorm2d(num_features=filter_size) for filter_size in filters[1:]])

post_conv_height = self.calculate_post_conv_height(
num_mel, 3, 2, 1, num_layers)
post_conv_height = self.calculate_post_conv_height(num_mel, 3, 2, 1, num_layers)
self.recurrence = nn.GRU(
input_size=filters[-1] * post_conv_height,
hidden_size=embedding_dim // 2,
batch_first=True)
input_size=filters[-1] * post_conv_height, hidden_size=embedding_dim // 2, batch_first=True
)

def forward(self, inputs):
batch_size = inputs.size(0)

@ -81,8 +73,7 @@ class ReferenceEncoder(nn.Module):
return out.squeeze(0)

@staticmethod
def calculate_post_conv_height(height, kernel_size, stride, pad,
n_convs):
def calculate_post_conv_height(height, kernel_size, stride, pad, n_convs):
"""Height of spec after n convolutions with fixed kernel/stride/pad."""
for _ in range(n_convs):
height = (height - kernel_size + 2 * pad) // stride + 1

@ -92,8 +83,7 @@ class ReferenceEncoder(nn.Module):
class StyleTokenLayer(nn.Module):
"""NN Module attending to style tokens based on prosody encodings."""

def __init__(self, num_heads, num_style_tokens,
embedding_dim, speaker_embedding_dim=None):
def __init__(self, num_heads, num_style_tokens, embedding_dim, speaker_embedding_dim=None):
super().__init__()

self.query_dim = embedding_dim // 2

@ -102,35 +92,31 @@ class StyleTokenLayer(nn.Module):
self.query_dim += speaker_embedding_dim

self.key_dim = embedding_dim // num_heads
self.style_tokens = nn.Parameter(
torch.FloatTensor(num_style_tokens, self.key_dim))
self.style_tokens = nn.Parameter(torch.FloatTensor(num_style_tokens, self.key_dim))
nn.init.normal_(self.style_tokens, mean=0, std=0.5)
self.attention = MultiHeadAttention(
query_dim=self.query_dim,
key_dim=self.key_dim,
num_units=embedding_dim,
num_heads=num_heads)
query_dim=self.query_dim, key_dim=self.key_dim, num_units=embedding_dim, num_heads=num_heads
)

def forward(self, inputs):
batch_size = inputs.size(0)
prosody_encoding = inputs.unsqueeze(1)
# prosody_encoding: 3D tensor [batch_size, 1, encoding_size==128]
tokens = torch.tanh(self.style_tokens) \
.unsqueeze(0) \
.expand(batch_size, -1, -1)
tokens = torch.tanh(self.style_tokens).unsqueeze(0).expand(batch_size, -1, -1)
# tokens: 3D tensor [batch_size, num tokens, token embedding size]
style_embed = self.attention(prosody_encoding, tokens)

return style_embed


class MultiHeadAttention(nn.Module):
'''
"""
input:
query --- [N, T_q, query_dim]
key --- [N, T_k, key_dim]
output:
out --- [N, T_q, num_units]
'''
"""

def __init__(self, query_dim, key_dim, num_units, num_heads):


@ -139,12 +125,9 @@ class MultiHeadAttention(nn.Module):
self.num_heads = num_heads
self.key_dim = key_dim

self.W_query = nn.Linear(
in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(
in_features=key_dim, out_features=num_units, bias=False)
self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)

def forward(self, query, key):
queries = self.W_query(query)  # [N, T_q, num_units]

@ -152,25 +135,17 @@ class MultiHeadAttention(nn.Module):
values = self.W_value(key)

split_size = self.num_units // self.num_heads
queries = torch.stack(
torch.split(queries, split_size, dim=2),
dim=0)  # [h, N, T_q, num_units/h]
keys = torch.stack(
torch.split(keys, split_size, dim=2),
dim=0)  # [h, N, T_k, num_units/h]
values = torch.stack(
torch.split(values, split_size, dim=2),
dim=0)  # [h, N, T_k, num_units/h]
queries = torch.stack(torch.split(queries, split_size, dim=2), dim=0)  # [h, N, T_q, num_units/h]
keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]
values = torch.stack(torch.split(values, split_size, dim=2), dim=0)  # [h, N, T_k, num_units/h]

# score = softmax(QK^T / (d_k ** 0.5))
scores = torch.matmul(queries, keys.transpose(2, 3))  # [h, N, T_q, T_k]
scores = scores / (self.key_dim**0.5)
scores = scores / (self.key_dim ** 0.5)
scores = F.softmax(scores, dim=3)

# out = score * V
out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
out = torch.cat(
torch.split(out, 1, dim=0),
dim=3).squeeze(0)  # [N, T_q, num_units]
out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)  # [N, T_q, num_units]

return out

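A quick sanity-check of the head split/merge bookkeeping in the forward above (standalone sketch with toy sizes; the learned projections are replaced by random tensors):

import torch

N, T_q, T_k, num_units, h = 2, 1, 10, 128, 4
q = torch.randn(N, T_q, num_units)
k = torch.randn(N, T_k, num_units)
v = torch.randn(N, T_k, num_units)

split = num_units // h
qh = torch.stack(torch.split(q, split, dim=2), dim=0)   # (h, N, T_q, split)
kh = torch.stack(torch.split(k, split, dim=2), dim=0)
vh = torch.stack(torch.split(v, split, dim=2), dim=0)
scores = torch.softmax(qh @ kh.transpose(2, 3) / split ** 0.5, dim=3)
out = torch.cat(torch.split(scores @ vh, 1, dim=0), dim=3).squeeze(0)   # (N, T_q, num_units)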
@ -23,24 +23,14 @@ class BatchNormConv1d(nn.Module):
- output: (B, D)
"""

def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding,
activation=None):
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, activation=None):

super(BatchNormConv1d, self).__init__()
super().__init__()
self.padding = padding
self.padder = nn.ConstantPad1d(padding, 0)
self.conv1d = nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
bias=False)
in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0, bias=False
)
# Following tensorflow's default parameters
self.bn = nn.BatchNorm1d(out_channels, momentum=0.99, eps=1e-3)
self.activation = activation

@ -48,15 +38,14 @@ class BatchNormConv1d(nn.Module):

def init_layers(self):
if isinstance(self.activation, torch.nn.ReLU):
w_gain = 'relu'
w_gain = "relu"
elif isinstance(self.activation, torch.nn.Tanh):
w_gain = 'tanh'
w_gain = "tanh"
elif self.activation is None:
w_gain = 'linear'
w_gain = "linear"
else:
raise RuntimeError('Unknown activation function')
torch.nn.init.xavier_uniform_(
self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))
raise RuntimeError("Unknown activation function")
torch.nn.init.xavier_uniform_(self.conv1d.weight, gain=torch.nn.init.calculate_gain(w_gain))

def forward(self, x):
x = self.padder(x)

@ -81,7 +70,7 @@ class Highway(nn.Module):

# TODO: Try GLU layer
def __init__(self, in_features, out_feature):
super(Highway, self).__init__()
super().__init__()
self.H = nn.Linear(in_features, out_feature)
self.H.bias.data.zero_()
self.T = nn.Linear(in_features, out_feature)

@ -91,10 +80,8 @@ class Highway(nn.Module):
# self.init_layers()

def init_layers(self):
torch.nn.init.xavier_uniform_(
self.H.weight, gain=torch.nn.init.calculate_gain('relu'))
torch.nn.init.xavier_uniform_(
self.T.weight, gain=torch.nn.init.calculate_gain('sigmoid'))
torch.nn.init.xavier_uniform_(self.H.weight, gain=torch.nn.init.calculate_gain("relu"))
torch.nn.init.xavier_uniform_(self.T.weight, gain=torch.nn.init.calculate_gain("sigmoid"))

def forward(self, inputs):
H = self.relu(self.H(inputs))

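For reference, the combination the truncated Highway forward computes is y = H(x)·T(x) + x·(1−T(x)); a standalone sketch of that gating, with random tensors standing in for the learned layers above (assumed, not commit code):

import torch

x = torch.randn(4, 128)
H = torch.relu(torch.randn(4, 128))     # stand-in for self.relu(self.H(x))
T = torch.sigmoid(torch.randn(4, 128))  # stand-in for the transform gate self.T(x)
y = H * T + x * (1.0 - T)               # gate mixes transformed and raw input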
@ -104,30 +91,33 @@ class Highway(nn.Module):

class CBHG(nn.Module):
"""CBHG module: a recurrent neural network composed of:
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units
- 1-d convolution banks
- Highway networks + residual connections
- Bidirectional gated recurrent units

Args:
in_features (int): sample size
K (int): max filter size in conv bank
projections (list): conv channel sizes for conv projections
num_highways (int): number of highways layers
Args:
in_features (int): sample size
K (int): max filter size in conv bank
projections (list): conv channel sizes for conv projections
num_highways (int): number of highways layers

Shapes:
- input: (B, C, T_in)
- output: (B, T_in, C*2)
Shapes:
- input: (B, C, T_in)
- output: (B, T_in, C*2)
"""
#pylint: disable=dangerous-default-value
def __init__(self,
in_features,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4):
super(CBHG, self).__init__()

# pylint: disable=dangerous-default-value
def __init__(
self,
in_features,
K=16,
conv_bank_features=128,
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4,
):
super().__init__()
self.in_features = in_features
self.conv_bank_features = conv_bank_features
self.highway_features = highway_features

@ -136,14 +126,19 @@ class CBHG(nn.Module):
self.relu = nn.ReLU()
# list of conv1d bank with filter size k=1...K
# TODO: try dilational layers instead
self.conv1d_banks = nn.ModuleList([
BatchNormConv1d(in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu) for k in range(1, K + 1)
])
self.conv1d_banks = nn.ModuleList(
[
BatchNormConv1d(
in_features,
conv_bank_features,
kernel_size=k,
stride=1,
padding=[(k - 1) // 2, k // 2],
activation=self.relu,
)
for k in range(1, K + 1)
]
)
# max pooling of conv bank, with padding
# TODO: try average pooling OR larger kernel size
out_features = [K * conv_bank_features] + conv_projections[:-1]

@ -151,31 +146,16 @@ class CBHG(nn.Module):
activations += [None]
# setup conv1d projection layers
layer_set = []
for (in_size, out_size, ac) in zip(out_features, conv_projections,
activations):
layer = BatchNormConv1d(in_size,
out_size,
kernel_size=3,
stride=1,
padding=[1, 1],
activation=ac)
for (in_size, out_size, ac) in zip(out_features, conv_projections, activations):
layer = BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1, padding=[1, 1], activation=ac)
layer_set.append(layer)
self.conv1d_projections = nn.ModuleList(layer_set)
# setup Highway layers
if self.highway_features != conv_projections[-1]:
self.pre_highway = nn.Linear(conv_projections[-1],
highway_features,
bias=False)
self.highways = nn.ModuleList([
Highway(highway_features, highway_features)
for _ in range(num_highways)
])
self.pre_highway = nn.Linear(conv_projections[-1], highway_features, bias=False)
self.highways = nn.ModuleList([Highway(highway_features, highway_features) for _ in range(num_highways)])
# bi-directional GRU layer
self.gru = nn.GRU(gru_features,
gru_features,
1,
batch_first=True,
bidirectional=True)
self.gru = nn.GRU(gru_features, gru_features, 1, batch_first=True, bidirectional=True)

def forward(self, inputs):
# (B, in_features, T_in)

@ -210,7 +190,7 @@ class EncoderCBHG(nn.Module):
r"""CBHG module with Encoder specific arguments"""

def __init__(self):
super(EncoderCBHG, self).__init__()
super().__init__()
self.cbhg = CBHG(
128,
K=16,

@ -218,7 +198,8 @@ class EncoderCBHG(nn.Module):
conv_projections=[128, 128],
highway_features=128,
gru_features=128,
num_highways=4)
num_highways=4,
)

def forward(self, x):
return self.cbhg(x)

@ -235,7 +216,7 @@ class Encoder(nn.Module):
"""

def __init__(self, in_features):
super(Encoder, self).__init__()
super().__init__()
self.prenet = Prenet(in_features, out_features=[256, 128])
self.cbhg = EncoderCBHG()


@ -248,7 +229,7 @@ class Encoder(nn.Module):

class PostCBHG(nn.Module):
def __init__(self, mel_dim):
super(PostCBHG, self).__init__()
super().__init__()
self.cbhg = CBHG(
mel_dim,
K=8,

@ -256,7 +237,8 @@ class PostCBHG(nn.Module):
conv_projections=[256, mel_dim],
highway_features=128,
gru_features=128,
num_highways=4)
num_highways=4,
)

def forward(self, x):
return self.cbhg(x)

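One subtlety in the conv bank above: for even kernel sizes k, `padding=[(k - 1) // 2, k // 2]` pads one extra frame on the right so the output length matches the input. A minimal check (sketch, not commit code):

import torch
import torch.nn as nn

k, T = 4, 10                                           # even kernel, arbitrary length
pad = nn.ConstantPad1d(((k - 1) // 2, k // 2), 0.0)    # (1, 2) for k=4
conv = nn.Conv1d(8, 8, kernel_size=k, stride=1, padding=0)
x = torch.randn(1, 8, T)
assert conv(pad(x)).shape[-1] == T                     # length preserved for every k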
@ -289,11 +271,25 @@ class Decoder(nn.Module):
# Pylint gets confused by PyTorch conventions here
# pylint: disable=attribute-defined-outside-init

def __init__(self, in_channels, frame_channels, r, memory_size, attn_type, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet):
super(Decoder, self).__init__()
def __init__(
self,
in_channels,
frame_channels,
r,
memory_size,
attn_type,
attn_windowing,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
):
super().__init__()
self.r_init = r
self.r = r
self.in_channels = in_channels

@ -305,33 +301,30 @@ class Decoder(nn.Module):
self.query_dim = 256
# memory -> |Prenet| -> processed_memory
prenet_dim = frame_channels * self.memory_size if self.use_memory_queue else frame_channels
self.prenet = Prenet(
prenet_dim,
prenet_type,
prenet_dropout,
out_features=[256, 128])
self.prenet = Prenet(prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128])
# processed_inputs, processed_memory -> |Attention| -> Attention, attention, RNN_State
# attention_rnn generates queries for the attention mechanism
self.attention_rnn = nn.GRUCell(in_channels + 128, self.query_dim)

self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K,
)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_channels, 256)
# decoder_RNN_input -> |RNN| -> RNN_state
self.decoder_rnns = nn.ModuleList(
[nn.GRUCell(256, 256) for _ in range(2)])
self.decoder_rnns = nn.ModuleList([nn.GRUCell(256, 256) for _ in range(2)])
# RNN_state -> |Linear| -> mel_spec
self.proj_to_mel = nn.Linear(256, frame_channels * self.r_init)
# learn init values instead of zero init.

@ -364,8 +357,7 @@ class Decoder(nn.Module):
# decoder states
self.attention_rnn_hidden = torch.zeros(1, device=inputs.device).repeat(B, 256)
self.decoder_rnn_hiddens = [
torch.zeros(1, device=inputs.device).repeat(B, 256)
for idx in range(len(self.decoder_rnns))
torch.zeros(1, device=inputs.device).repeat(B, 256) for idx in range(len(self.decoder_rnns))
]
self.context_vec = inputs.data.new(B, self.in_channels).zero_()
# cache attention inputs

@ -376,8 +368,7 @@ class Decoder(nn.Module):
attentions = torch.stack(attentions).transpose(0, 1)
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(
outputs.size(0), -1, self.frame_channels)
outputs = outputs.view(outputs.size(0), -1, self.frame_channels)
outputs = outputs.transpose(1, 2)
return outputs, attentions, stop_tokens


@ -386,18 +377,15 @@ class Decoder(nn.Module):
processed_memory = self.prenet(self.memory_input)
# Attention RNN
self.attention_rnn_hidden = self.attention_rnn(
torch.cat((processed_memory, self.context_vec), -1),
self.attention_rnn_hidden)
self.context_vec = self.attention(
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
torch.cat((processed_memory, self.context_vec), -1), self.attention_rnn_hidden
)
self.context_vec = self.attention(self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
decoder_input = self.project_to_decoder_in(torch.cat((self.attention_rnn_hidden, self.context_vec), -1))

# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](
decoder_input, self.decoder_rnn_hiddens[idx])
self.decoder_rnn_hiddens[idx] = self.decoder_rnns[idx](decoder_input, self.decoder_rnn_hiddens[idx])
# Residual connection
decoder_input = self.decoder_rnn_hiddens[idx] + decoder_input
decoder_output = decoder_input
|
|||
if self.use_memory_queue:
|
||||
if self.memory_size > self.r:
|
||||
# memory queue size is larger than number of frames per decoder iter
|
||||
self.memory_input = torch.cat([
|
||||
new_memory, self.memory_input[:, :(
|
||||
self.memory_size - self.r) * self.frame_channels].clone()
|
||||
], dim=-1)
|
||||
self.memory_input = torch.cat(
|
||||
[new_memory, self.memory_input[:, : (self.memory_size - self.r) * self.frame_channels].clone()],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
# memory queue size smaller than number of frames per decoder iter
|
||||
self.memory_input = new_memory[:, :self.memory_size * self.frame_channels]
|
||||
self.memory_input = new_memory[:, : self.memory_size * self.frame_channels]
|
||||
else:
|
||||
# use only the last frame prediction
|
||||
# assert new_memory.shape[-1] == self.r * self.frame_channels
|
||||
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1):]
|
||||
self.memory_input = new_memory[:, self.frame_channels * (self.r - 1) :]
|
||||
|
||||
def forward(self, inputs, memory, mask):
|
||||
"""
|
||||
|
@ -487,8 +475,7 @@ class Decoder(nn.Module):
|
|||
attentions += [attention]
|
||||
stop_tokens += [stop_token]
|
||||
t += 1
|
||||
if t > inputs.shape[1] / 4 and (stop_token > 0.6
|
||||
or attention[:, -1].item() > 0.6):
|
||||
if t > inputs.shape[1] / 4 and (stop_token > 0.6 or attention[:, -1].item() > 0.6):
|
||||
break
|
||||
if t > self.max_decoder_steps:
|
||||
print(" | > Decoder stopped with 'max_decoder_steps'")

@ -503,11 +490,10 @@ class StopNet(nn.Module):
"""

def __init__(self, in_features):
super(StopNet, self).__init__()
super().__init__()
self.dropout = nn.Dropout(0.1)
self.linear = nn.Linear(in_features, 1)
torch.nn.init.xavier_uniform_(
self.linear.weight, gain=torch.nn.init.calculate_gain('linear'))
torch.nn.init.xavier_uniform_(self.linear.weight, gain=torch.nn.init.calculate_gain("linear"))

def forward(self, inputs):
outputs = self.dropout(inputs)

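For reference, the early-stop test refactored in the inference hunk above reduces to a single predicate. A minimal sketch, assuming a scalar stopnet probability and a 1-D torch tensor of attention weights per step (function and parameter names are illustrative, not code from this commit):

def should_stop(t, num_encoder_steps, stop_token, attention, max_decoder_steps=500):
    # Wait out the first quarter of the input, then stop once either the
    # stopnet probability or the attention weight on the last input fires.
    past_warmup = t > num_encoder_steps / 4
    fired = stop_token > 0.6 or attention[-1].item() > 0.6
    return (past_warmup and fired) or t > max_decoder_steps
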
@ -5,8 +5,8 @@ from .common_layers import Prenet, Linear
from .attentions import init_attn

# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(nn.Module):
r"""Convolutions with Batch Normalization and non-linear activation.

@ -20,19 +20,17 @@ class ConvBNBlock(nn.Module):
- input: (B, C_in, T)
- output: (B, C_out, T)
"""

def __init__(self, in_channels, out_channels, kernel_size, activation=None):
super(ConvBNBlock, self).__init__()
super().__init__()
assert (kernel_size - 1) % 2 == 0
padding = (kernel_size - 1) // 2
self.convolution1d = nn.Conv1d(in_channels,
out_channels,
kernel_size,
padding=padding)
self.convolution1d = nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding)
self.batch_normalization = nn.BatchNorm1d(out_channels, momentum=0.1, eps=1e-5)
self.dropout = nn.Dropout(p=0.5)
if activation == 'relu':
if activation == "relu":
self.activation = nn.ReLU()
elif activation == 'tanh':
elif activation == "tanh":
self.activation = nn.Tanh()
else:
self.activation = nn.Identity()

@ -55,16 +53,14 @@ class Postnet(nn.Module):
- input: (B, C_in, T)
- output: (B, C_in, T)
"""

def __init__(self, in_out_channels, num_convs=5):
super(Postnet, self).__init__()
super().__init__()
self.convolutions = nn.ModuleList()
self.convolutions.append(
ConvBNBlock(in_out_channels, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(ConvBNBlock(in_out_channels, 512, kernel_size=5, activation="tanh"))
for _ in range(1, num_convs - 1):
self.convolutions.append(
ConvBNBlock(512, 512, kernel_size=5, activation='tanh'))
self.convolutions.append(
ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))
self.convolutions.append(ConvBNBlock(512, 512, kernel_size=5, activation="tanh"))
self.convolutions.append(ConvBNBlock(512, in_out_channels, kernel_size=5, activation=None))

def forward(self, x):
o = x

@ -83,18 +79,15 @@ class Encoder(nn.Module):
- input: (B, C_in, T)
- output: (B, C_in, T)
"""

def __init__(self, in_out_channels=512):
super(Encoder, self).__init__()
super().__init__()
self.convolutions = nn.ModuleList()
for _ in range(3):
self.convolutions.append(
ConvBNBlock(in_out_channels, in_out_channels, 5, 'relu'))
self.lstm = nn.LSTM(in_out_channels,
int(in_out_channels / 2),
num_layers=1,
batch_first=True,
bias=True,
bidirectional=True)
self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu"))
self.lstm = nn.LSTM(
in_out_channels, int(in_out_channels / 2), num_layers=1, batch_first=True, bias=True, bidirectional=True
)
self.rnn_state = None

def forward(self, x, input_lengths):

@ -102,9 +95,7 @@ class Encoder(nn.Module):
for layer in self.convolutions:
o = layer(o)
o = o.transpose(1, 2)
o = nn.utils.rnn.pack_padded_sequence(o,
input_lengths.cpu(),
batch_first=True)
o = nn.utils.rnn.pack_padded_sequence(o, input_lengths.cpu(), batch_first=True)
self.lstm.flatten_parameters()
o, _ = self.lstm(o)
o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True)

@ -143,12 +134,27 @@ class Decoder(nn.Module):
attn_K (int): number of attention heads for GravesAttention.
separate_stopnet (bool): if true, detach stopnet input to prevent gradient flow.
"""

# Pylint gets confused by PyTorch conventions here
#pylint: disable=attribute-defined-outside-init
def __init__(self, in_channels, frame_channels, r, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
forward_attn_mask, location_attn, attn_K, separate_stopnet):
super(Decoder, self).__init__()
# pylint: disable=attribute-defined-outside-init
def __init__(
self,
in_channels,
frame_channels,
r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
):
super().__init__()
self.frame_channels = frame_channels
self.r_init = r
self.r = r

@ -167,43 +173,36 @@ class Decoder(nn.Module):

# memory -> |Prenet| -> processed_memory
prenet_dim = self.frame_channels
self.prenet = Prenet(prenet_dim,
prenet_type,
prenet_dropout,
out_features=[self.prenet_dim, self.prenet_dim],
bias=False)
self.prenet = Prenet(
prenet_dim, prenet_type, prenet_dropout, out_features=[self.prenet_dim, self.prenet_dim], bias=False
)

self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels,
self.query_dim,
bias=True)
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_channels, self.query_dim, bias=True)

self.attention = init_attn(attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K)
self.attention = init_attn(
attn_type=attn_type,
query_dim=self.query_dim,
embedding_dim=in_channels,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask,
attn_K=attn_K,
)

self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels,
self.decoder_rnn_dim,
bias=True)
self.decoder_rnn = nn.LSTMCell(self.query_dim + in_channels, self.decoder_rnn_dim, bias=True)

self.linear_projection = Linear(self.decoder_rnn_dim + in_channels,
self.frame_channels * self.r_init)
self.linear_projection = Linear(self.decoder_rnn_dim + in_channels, self.frame_channels * self.r_init)

self.stopnet = nn.Sequential(
nn.Dropout(0.1),
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init,
1,
bias=True,
init_gain='sigmoid'))
Linear(self.decoder_rnn_dim + self.frame_channels * self.r_init, 1, bias=True, init_gain="sigmoid"),
)
self.memory_truncated = None

def set_r(self, new_r):

@ -211,24 +210,18 @@ class Decoder(nn.Module):

def get_go_frame(self, inputs):
B = inputs.size(0)
memory = torch.zeros(1, device=inputs.device).repeat(
B, self.frame_channels * self.r)
memory = torch.zeros(1, device=inputs.device).repeat(B, self.frame_channels * self.r)
return memory

def _init_states(self, inputs, mask, keep_states=False):
B = inputs.size(0)
# T = inputs.size(1)
if not keep_states:
self.query = torch.zeros(1, device=inputs.device).repeat(
B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(
1, device=inputs.device).repeat(B, self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(
B, self.encoder_embedding_dim)
self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B, self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(B, self.encoder_embedding_dim)
self.inputs = inputs
self.processed_inputs = self.attention.preprocess_inputs(inputs)
self.mask = mask

@ -254,38 +247,36 @@ class Decoder(nn.Module):

def _update_memory(self, memory):
if len(memory.shape) == 2:
return memory[:, self.frame_channels * (self.r - 1):]
return memory[:, :, self.frame_channels * (self.r - 1):]
return memory[:, self.frame_channels * (self.r - 1) :]
return memory[:, :, self.frame_channels * (self.r - 1) :]

def decode(self, memory):
'''
shapes:
- memory: B x r * self.frame_channels
'''
"""
shapes:
- memory: B x r * self.frame_channels
"""
# self.context: B x D_en
# query_input: B x D_en + (r * self.frame_channels)
query_input = torch.cat((memory, self.context), -1)
# self.query and self.attention_rnn_cell_state : B x D_attn_rnn
self.query, self.attention_rnn_cell_state = self.attention_rnn(
query_input, (self.query, self.attention_rnn_cell_state))
self.query = F.dropout(self.query, self.p_attention_dropout,
self.training)
query_input, (self.query, self.attention_rnn_cell_state)
)
self.query = F.dropout(self.query, self.p_attention_dropout, self.training)
self.attention_rnn_cell_state = F.dropout(
self.attention_rnn_cell_state, self.p_attention_dropout,
self.training)
self.attention_rnn_cell_state, self.p_attention_dropout, self.training
)
# B x D_en
self.context = self.attention(self.query, self.inputs,
self.processed_inputs, self.mask)
self.context = self.attention(self.query, self.inputs, self.processed_inputs, self.mask)
# B x (D_en + D_attn_rnn)
decoder_rnn_input = torch.cat((self.query, self.context), -1)
# self.decoder_hidden and self.decoder_cell: B x D_decoder_rnn
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(self.decoder_hidden,
self.p_decoder_dropout, self.training)
decoder_rnn_input, (self.decoder_hidden, self.decoder_cell)
)
self.decoder_hidden = F.dropout(self.decoder_hidden, self.p_decoder_dropout, self.training)
# B x (D_decoder_rnn + D_en)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context),
dim=1)
decoder_hidden_context = torch.cat((self.decoder_hidden, self.context), dim=1)
# B x (self.r * self.frame_channels)
decoder_output = self.linear_projection(decoder_hidden_context)
# B x (D_decoder_rnn + (self.r * self.frame_channels))

@ -295,7 +286,7 @@ class Decoder(nn.Module):
else:
stop_token = self.stopnet(stopnet_input)
# select outputs for the reduction rate self.r
decoder_output = decoder_output[:, :self.r * self.frame_channels]
decoder_output = decoder_output[:, : self.r * self.frame_channels]
return decoder_output, self.attention.attention_weights, stop_token

def forward(self, inputs, memories, mask):

@ -329,8 +320,7 @@ class Decoder(nn.Module):
stop_tokens += [stop_token.squeeze(1)]
alignments += [attention_weights]

outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)
return outputs, alignments, stop_tokens

def inference(self, inputs):

@ -369,8 +359,7 @@ class Decoder(nn.Module):
memory = self._update_memory(decoder_output)
t += 1

outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)

return outputs, alignments, stop_tokens

@ -404,8 +393,7 @@ class Decoder(nn.Module):
self.memory_truncated = decoder_output
t += 1

outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(outputs, stop_tokens, alignments)

return outputs, alignments, stop_tokens

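For reference, `_update_memory` above implements the reduction-factor trick used throughout these decoders: the model predicts `r` frames per step but feeds only the last one back as the next input. A minimal sketch with illustrative shapes (not code from the commit):

import torch

frame_channels, r = 80, 2
new_memory = torch.randn(4, r * frame_channels)  # B x (r * frame_channels)
# keep only the final predicted frame as the next decoder input
last_frame = new_memory[:, frame_channels * (r - 1) :]
assert last_frame.shape == (4, frame_channels)
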
@ -62,39 +62,27 @@ class AlignTTS(nn.Module):
# pylint: disable=dangerous-default-value

def __init__(
self,
num_chars,
out_channels,
hidden_channels=256,
hidden_channels_dp=256,
encoder_type='fftransformer',
encoder_params={
'hidden_channels_ffn': 1024,
'num_heads': 2,
'num_layers': 6,
'dropout_p': 0.1
},
decoder_type='fftransformer',
decoder_params={
'hidden_channels_ffn': 1024,
'num_heads': 2,
'num_layers': 6,
'dropout_p': 0.1
},
length_scale=1,
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels=256,
hidden_channels_dp=256,
encoder_type="fftransformer",
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
decoder_type="fftransformer",
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
length_scale=1,
num_speakers=0,
external_c=False,
c_in_channels=0,
):

super().__init__()
self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.pos_encoder = PositionalEncoding(hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels_dp)

self.mod_layer = nn.Conv1d(hidden_channels, hidden_channels, 1)

@ -111,13 +99,13 @@ class AlignTTS(nn.Module):
@staticmethod
def compute_log_probs(mu, log_sigma, y):
# pylint: disable=protected-access, c-extension-no-member
y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
y = y.transpose(1, 2).unsqueeze(1) # [B, 1, T1, D]
mu = mu.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
log_sigma = log_sigma.transpose(1, 2).unsqueeze(2) # [B, T2, 1, D]
expanded_y, expanded_mu = torch.broadcast_tensors(y, mu)
exponential = -0.5 * torch.mean(torch._C._nn.mse_loss(
expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2),
dim=-1) # B, L, T
exponential = -0.5 * torch.mean(
torch._C._nn.mse_loss(expanded_y, expanded_mu, 0) / torch.pow(log_sigma.exp(), 2), dim=-1
) # B, L, T
logp = exponential - 0.5 * log_sigma.mean(dim=-1)
return logp

@ -151,9 +139,7 @@ class AlignTTS(nn.Module):
[1, 0, 0, 0, 0, 0, 0]]
"""
attn = self.convert_dr_to_align(dr, x_mask, y_mask)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn

def format_durations(self, o_dr_log, x_mask):

@ -170,12 +156,12 @@ class AlignTTS(nn.Module):

def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, 'proj_g'):
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g

def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, 'emb_g'):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]

if g is not None:

@ -187,8 +173,7 @@ class AlignTTS(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)

# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

# encoder pass
o_en = self.encoder(x_emb, x_mask)

@ -201,12 +186,11 @@ class AlignTTS(nn.Module):
return o_en, o_en_dp, x_mask, g

def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, 'pos_encoder'):
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:

@ -218,10 +202,8 @@ class AlignTTS(nn.Module):
def _forward_mdn(self, o_en, y, y_lengths, x_mask):
# MAS potentials and alignment
mu, log_sigma = self.mdn_block(o_en)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask,
y_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
dr_mas, logp = self.compute_align_path(mu, log_sigma, y, x_mask, y_mask)
return dr_mas, mu, log_sigma, logp

def forward(self, x, x_lengths, y, y_lengths, phase=None, g=None): # pylint: disable=unused-argument

@ -237,56 +219,31 @@ class AlignTTS(nn.Module):
if phase == 0:
# train encoder and MDN
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
attn = self.convert_dr_to_align(dr_mas, x_mask, y_mask)
elif phase == 1:
# train decoder
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, _, _, _ = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en.detach(),
o_en_dp.detach(),
dr_mas.detach(),
x_mask,
y_lengths,
g=g)
o_de, attn = self._forward_decoder(o_en.detach(), o_en_dp.detach(), dr_mas.detach(), x_mask, y_lengths, g=g)
elif phase == 2:
# train the whole except duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
elif phase == 3:
# train duration predictor
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(x, x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
else:
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
dr_mas, mu, log_sigma, logp = self._forward_mdn(
o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr_mas,
x_mask,
y_lengths,
g=g)
dr_mas, mu, log_sigma, logp = self._forward_mdn(o_en, y, y_lengths, x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr_mas, x_mask, y_lengths, g=g)
o_dr_log = o_dr_log.squeeze(1)
dr_mas_log = torch.log(dr_mas + 1).squeeze(1)
return o_de, o_dr_log, dr_mas_log, attn, mu, log_sigma, logp

@ -307,17 +264,14 @@ class AlignTTS(nn.Module):
# duration predictor pass
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn

def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training

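For reference, `compute_log_probs` above scores every (text position, frame) pair under a diagonal Gaussian, and the `torch._C._nn.mse_loss(..., 0)` call is an unreduced squared error. An equivalent sketch using only public ops, with illustrative shapes (not code from the commit):

import torch

B, T1, T2, D = 2, 7, 5, 4
y = torch.randn(B, D, T1).transpose(1, 2).unsqueeze(1)          # [B, 1, T1, D]
mu = torch.randn(B, D, T2).transpose(1, 2).unsqueeze(2)         # [B, T2, 1, D]
log_sigma = torch.randn(B, D, T2).transpose(1, 2).unsqueeze(2)  # [B, T2, 1, D]
# squared error broadcast over (T2, T1), scaled by the per-dim variance
exponential = -0.5 * torch.mean((y - mu) ** 2 / log_sigma.exp() ** 2, dim=-1)
logp = exponential - 0.5 * log_sigma.mean(dim=-1)  # [B, T2, T1]
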
@ -34,28 +34,31 @@ class GlowTTS(nn.Module):
encoder_params (dict): encoder module parameters.
external_speaker_embedding_dim (int): channels of external speaker embedding vectors.
"""
def __init__(self,
num_chars,
hidden_channels_enc,
hidden_channels_dec,
use_encoder_prenet,
hidden_channels_dp,
out_channels,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
dropout_p_dp=0.1,
dropout_p_dec=0.05,
num_speakers=0,
c_in_channels=0,
num_splits=4,
num_squeeze=1,
sigmoid_scale=False,
mean_only=False,
encoder_type="transformer",
encoder_params=None,
external_speaker_embedding_dim=None):

def __init__(
self,
num_chars,
hidden_channels_enc,
hidden_channels_dec,
use_encoder_prenet,
hidden_channels_dp,
out_channels,
num_flow_blocks_dec=12,
kernel_size_dec=5,
dilation_rate=5,
num_block_layers=4,
dropout_p_dp=0.1,
dropout_p_dec=0.05,
num_speakers=0,
c_in_channels=0,
num_splits=4,
num_squeeze=1,
sigmoid_scale=False,
mean_only=False,
encoder_type="transformer",
encoder_params=None,
external_speaker_embedding_dim=None,
):

super().__init__()
self.num_chars = num_chars

@ -78,7 +81,7 @@ class GlowTTS(nn.Module):

# model constants.
self.noise_scale = 0.33 # defines the noise variance applied to the random z vector at inference.
self.length_scale = 1. # scaler for the duration predictor. The larger it is, the slower the speech.
self.length_scale = 1.0 # scaler for the duration predictor. The larger it is, the slower the speech.
self.external_speaker_embedding_dim = external_speaker_embedding_dim

# if is a multispeaker and c_in_channels is 0, set to 256

@ -88,28 +91,32 @@ class GlowTTS(nn.Module):
elif self.external_speaker_embedding_dim:
self.c_in_channels = self.external_speaker_embedding_dim

self.encoder = Encoder(num_chars,
out_channels=out_channels,
hidden_channels=hidden_channels_enc,
hidden_channels_dp=hidden_channels_dp,
encoder_type=encoder_type,
encoder_params=encoder_params,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
dropout_p_dp=dropout_p_dp,
c_in_channels=self.c_in_channels)
self.encoder = Encoder(
num_chars,
out_channels=out_channels,
hidden_channels=hidden_channels_enc,
hidden_channels_dp=hidden_channels_dp,
encoder_type=encoder_type,
encoder_params=encoder_params,
mean_only=mean_only,
use_prenet=use_encoder_prenet,
dropout_p_dp=dropout_p_dp,
c_in_channels=self.c_in_channels,
)

self.decoder = Decoder(out_channels,
hidden_channels_dec,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
num_squeeze=num_squeeze,
sigmoid_scale=sigmoid_scale,
c_in_channels=self.c_in_channels)
self.decoder = Decoder(
out_channels,
hidden_channels_dec,
kernel_size_dec,
dilation_rate,
num_flow_blocks_dec,
num_block_layers,
dropout_p=dropout_p_dec,
num_splits=num_splits,
num_squeeze=num_squeeze,
sigmoid_scale=sigmoid_scale,
c_in_channels=self.c_in_channels,
)

if num_speakers > 1 and not external_speaker_embedding_dim:
# speaker embedding layer

@ -119,12 +126,12 @@ class GlowTTS(nn.Module):
@staticmethod
def compute_outputs(attn, o_mean, o_log_scale, x_mask):
# compute final values with the computed alignment
y_mean = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(
attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(
1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t']
y_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
y_log_scale = torch.matmul(attn.squeeze(1).transpose(1, 2), o_log_scale.transpose(1, 2)).transpose(
1, 2
) # [b, t', t], [b, t, d] -> [b, d, t']
# compute total duration with adjustment
o_attn_dur = torch.log(1 + torch.sum(attn, -1)) * x_mask
return y_mean, y_log_scale, o_attn_dur

@ -144,37 +151,27 @@ class GlowTTS(nn.Module):
if self.external_speaker_embedding_dim:
g = F.normalize(g).unsqueeze(-1)
else:
g = F.normalize(self.emb_g(g)).unsqueeze(-1)# [b, h, 1]
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h, 1]

# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# drop residual frames wrt num_squeeze and set y_lengths.
y, y_lengths, y_max_length, attn = self.preprocess(
y, y_lengths, y_max_length, None)
y, y_lengths, y_max_length, attn = self.preprocess(y, y_lengths, y_max_length, None)
# create masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# decoder pass
z, logdet = self.decoder(y, y_mask, g=g, reverse=False)
# find the alignment path
with torch.no_grad():
o_scale = torch.exp(-2 * o_log_scale)
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 *
(z**2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2),
z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale,
[1]).unsqueeze(-1) # [b, t, 1]
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2)) # [b, t, d] x [b, d, t'] = [b, t, t']
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z) # [b, t, d] x [b, d, t'] = [b, t, t']
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1) # [b, t, 1]
logp = logp1 + logp2 + logp3 + logp4 # [b, t, t']
attn = maximum_path(logp,
attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)
attn = attn.squeeze(1).permute(0, 2, 1)
return z, logdet, y_mean, y_log_scale, attn, o_dur_log, o_attn_dur

@ -187,26 +184,20 @@ class GlowTTS(nn.Module):
g = F.normalize(self.emb_g(g)).unsqueeze(-1) # [b, h]

# embedding pass
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x,
x_lengths,
g=g)
o_mean, o_log_scale, o_dur_log, x_mask = self.encoder(x, x_lengths, g=g)
# compute output durations
w = (torch.exp(o_dur_log) - 1) * x_mask * self.length_scale
w_ceil = torch.ceil(w)
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = None
# compute masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length),
1).to(x_mask.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(x_mask.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
# compute attention mask
attn = generate_path(w_ceil.squeeze(1),
attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(
attn, o_mean, o_log_scale, x_mask)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
y_mean, y_log_scale, o_attn_dur = self.compute_outputs(attn, o_mean, o_log_scale, x_mask)

z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) *
self.noise_scale) * y_mask
z = (y_mean + torch.exp(y_log_scale) * torch.randn_like(y_mean) * self.noise_scale) * y_mask
# decoder pass
y, logdet = self.decoder(z, y_mask, g=g, reverse=True)
attn = attn.squeeze(1).permute(0, 2, 1)

@ -224,9 +215,11 @@ class GlowTTS(nn.Module):
def store_inverse(self):
self.decoder.store_inverse()

def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
self.store_inverse()

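For reference, the four logp terms in the GlowTTS forward pass above are the expansion of a diagonal Gaussian log-density, log N(z; mu, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*z^2/sigma^2 + mu*z/sigma^2 - 0.5*mu^2/sigma^2, summed over channels, with o_scale = exp(-2 * o_log_scale) = 1/sigma^2; the matmuls evaluate it for every (text, frame) pair at once. A runnable sketch with toy shapes (shapes are assumptions, the ops mirror the hunk above):

import math
import torch

b, d, t, t_dec = 2, 8, 5, 11
o_mean = torch.randn(b, d, t)
o_log_scale = torch.randn(b, d, t)
z = torch.randn(b, d, t_dec)
o_scale = torch.exp(-2 * o_log_scale)  # 1 / sigma**2
logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (z ** 2))
logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), z)
logp4 = torch.sum(-0.5 * (o_mean ** 2) * o_scale, [1]).unsqueeze(-1)
logp = logp1 + logp2 + logp3 + logp4  # [b, t, t_dec]
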
@ -34,42 +34,37 @@ class SpeedySpeech(nn.Module):
external_c (bool, optional): enable external speaker embeddings. Defaults to False.
c_in_channels (int, optional): number of channels in speaker embedding vectors. Defaults to 0.
"""

# pylint: disable=dangerous-default-value

def __init__(
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type='residual_conv_bn',
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type='residual_conv_bn',
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17
},
num_speakers=0,
external_c=False,
c_in_channels=0):
self,
num_chars,
out_channels,
hidden_channels,
positional_encoding=True,
length_scale=1,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4, 8] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 17,
},
num_speakers=0,
external_c=False,
c_in_channels=0,
):

super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels,
decoder_type, decoder_params)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)

if num_speakers > 1 and not external_c:

@ -97,9 +92,7 @@ class SpeedySpeech(nn.Module):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn

def format_durations(self, o_dr_log, x_mask):

@ -116,12 +109,12 @@ class SpeedySpeech(nn.Module):

def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, 'proj_g'):
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g

def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, 'emb_g'):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]

if g is not None:

@ -133,8 +126,7 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)

# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)

# encoder pass
o_en = self.encoder(x_emb, x_mask)

@ -147,12 +139,11 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g

def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, 'pos_encoder'):
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:

@ -187,7 +178,7 @@ class SpeedySpeech(nn.Module):
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode='constant', value=0)
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)

@ -196,9 +187,11 @@ class SpeedySpeech(nn.Module):
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn

def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
self.load_state_dict(state['model'])
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval:
self.eval()
assert not self.training

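For reference, `expand_encoder_outputs` above turns integer durations into a binary alignment and repeats each encoder frame accordingly; the repo does it with its `generate_path` helper plus a matmul. A loop-based sketch of the same idea (illustrative only, durations assumed already rounded):

import torch

def expand(en, dr):
    # en: [B, D, T_en] encoder frames, dr: [B, T_en] integer durations
    attn = torch.zeros(en.size(0), dr.size(1), int(dr.sum(1).max()))
    for b in range(en.size(0)):
        t = 0
        for i, d in enumerate(dr[b].tolist()):
            attn[b, i, t : t + int(d)] = 1.0  # token i covers d output frames
            t += int(d)
    # [B, T_de, T_en] @ [B, T_en, D] -> [B, T_de, D] -> [B, D, T_de]
    return torch.matmul(attn.transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
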
@ -47,45 +47,67 @@ class Tacotron(TacotronAbstract):
memory_size (int, optional): size of the history queue fed to the prenet. Model feeds the last ```memory_size```
output frames to the prenet.
"""
def __init__(self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=256,
decoder_in_features=256,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
memory_size=5,
gst_use_speaker_embedding=False):
super(Tacotron,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, encoder_in_features, decoder_in_features,
speaker_embedding_dim, gst, gst_embedding_dim,
gst_num_heads, gst_style_tokens, gst_use_speaker_embedding)

def __init__(
self,
num_chars,
num_speakers,
r=5,
postnet_output_dim=1025,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="sigmoid",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=256,
decoder_in_features=256,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=256,
gst_num_heads=4,
gst_style_tokens=10,
memory_size=5,
gst_use_speaker_embedding=False,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)

# speaker embedding layers
if self.num_speakers > 1:

@ -96,7 +118,7 @@ class Tacotron(TacotronAbstract):

# speaker and gst embeddings is concat in decoder input
if self.num_speakers > 1:
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim

# embedding layer
self.embedding = nn.Embedding(num_chars, 256, padding_idx=0)

@ -104,32 +126,59 @@ class Tacotron(TacotronAbstract):

# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(self.decoder_in_features, decoder_output_dim, r,
memory_size, attn_type, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn,
attn_K, separate_stopnet)
self.decoder = Decoder(
self.decoder_in_features,
decoder_output_dim,
r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
self.postnet = PostCBHG(decoder_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
postnet_output_dim)
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)

# global style token layers
if self.gst:
self.gst_layer = GST(num_mel=80,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None)
self.gst_layer = GST(
num_mel=80,
num_heads=gst_num_heads,
num_style_tokens=gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(
self.decoder_in_features, decoder_output_dim, ddc_r, memory_size,
attn_type, attn_win, attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask, location_attn,
attn_K, separate_stopnet)
self.decoder_in_features,
decoder_output_dim,
ddc_r,
memory_size,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)

def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
"""

@ -151,9 +200,9 @@ class Tacotron(TacotronAbstract):
# global style token
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
mel_specs,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
# speaker embedding
if self.num_speakers > 1:
if not self.embeddings_per_sample:

@ -166,8 +215,7 @@ class Tacotron(TacotronAbstract):
# decoder_outputs: B x decoder_in_features x T_out
# alignments: B x T_in x encoder_in_features
# stop_tokens: B x T_in
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if output_mask is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)

@ -182,10 +230,26 @@ class Tacotron(TacotronAbstract):
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
if self.double_decoder_consistency:
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
mel_specs, encoder_outputs, alignments, input_mask
)
return (
decoder_outputs,
postnet_outputs,
alignments,
stop_tokens,
decoder_outputs_backward,
alignments_backward,
)
return decoder_outputs, postnet_outputs, alignments, stop_tokens

@torch.no_grad()

@ -194,9 +258,9 @@ class Tacotron(TacotronAbstract):
encoder_outputs = self.encoder(inputs)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
style_mel,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim

@ -205,8 +269,7 @@ class Tacotron(TacotronAbstract):
# B x 1 x speaker_embed_dim
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
encoder_outputs)
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
postnet_outputs = self.postnet(decoder_outputs)
postnet_outputs = self.last_linear(postnet_outputs)
decoder_outputs = decoder_outputs.transpose(1, 2)

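For reference, the reason `decoder_in_features` grows by `speaker_embedding_dim` in both Tacotron models above is that a per-utterance [B, 1, E] speaker embedding is broadcast over time and concatenated to the [B, T, C] encoder outputs before decoding. A minimal sketch with illustrative sizes (not code from the commit):

import torch

B, T, C, E = 2, 13, 256, 64
encoder_outputs = torch.randn(B, T, C)
speaker_embeddings = torch.randn(B, 1, E)
expanded = speaker_embeddings.expand(-1, T, -1)
decoder_inputs = torch.cat([encoder_outputs, expanded], dim=-1)  # [B, T, C + E]
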
@ -44,44 +44,66 @@ class Tacotron2(TacotronAbstract):
gst_style_tokens (int, optional): number of GST tokens. Defaults to 10.
gst_use_speaker_embedding (bool, optional): enable/disable inputting speaker embedding to GST. Defaults to False.
"""
def __init__(self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type='original',
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False):
super(Tacotron2,
self).__init__(num_chars, num_speakers, r, postnet_output_dim,
decoder_output_dim, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet,
bidirectional_decoder, double_decoder_consistency,
ddc_r, encoder_in_features, decoder_in_features,
speaker_embedding_dim, gst, gst_embedding_dim,
gst_num_heads, gst_style_tokens, gst_use_speaker_embedding)

def __init__(
self,
num_chars,
num_speakers,
r,
postnet_output_dim=80,
decoder_output_dim=80,
attn_type="original",
attn_win=False,
attn_norm="softmax",
prenet_type="original",
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
attn_K=5,
separate_stopnet=True,
bidirectional_decoder=False,
double_decoder_consistency=False,
ddc_r=None,
encoder_in_features=512,
decoder_in_features=512,
speaker_embedding_dim=None,
gst=False,
gst_embedding_dim=512,
gst_num_heads=4,
gst_style_tokens=10,
gst_use_speaker_embedding=False,
):
super().__init__(
num_chars,
num_speakers,
r,
postnet_output_dim,
decoder_output_dim,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
bidirectional_decoder,
double_decoder_consistency,
ddc_r,
encoder_in_features,
decoder_in_features,
speaker_embedding_dim,
gst,
gst_embedding_dim,
gst_num_heads,
gst_style_tokens,
gst_use_speaker_embedding,
)

# speaker embedding layer
if self.num_speakers > 1:

@ -92,36 +114,63 @@ class Tacotron2(TacotronAbstract):

# speaker and gst embeddings is concat in decoder input
if self.num_speakers > 1:
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim
self.decoder_in_features += speaker_embedding_dim # add speaker embedding dim

# embedding layer
self.embedding = nn.Embedding(num_chars, 512, padding_idx=0)

# base model layers
self.encoder = Encoder(self.encoder_in_features)
self.decoder = Decoder(self.decoder_in_features, self.decoder_output_dim, r, attn_type, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, forward_attn_mask,
location_attn, attn_K, separate_stopnet)
self.decoder = Decoder(
self.decoder_in_features,
self.decoder_output_dim,
r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)
self.postnet = Postnet(self.postnet_output_dim)

# global style token layers
if self.gst:
self.gst_layer = GST(num_mel=80,
num_heads=self.gst_num_heads,
num_style_tokens=self.gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim if self.embeddings_per_sample and self.gst_use_speaker_embedding else None)
self.gst_layer = GST(
num_mel=80,
num_heads=self.gst_num_heads,
num_style_tokens=self.gst_style_tokens,
gst_embedding_dim=self.gst_embedding_dim,
speaker_embedding_dim=speaker_embedding_dim
if self.embeddings_per_sample and self.gst_use_speaker_embedding
else None,
)
# backward pass decoder
if self.bidirectional_decoder:
self._init_backward_decoder()
# setup DDC
if self.double_decoder_consistency:
self.coarse_decoder = Decoder(
self.decoder_in_features, self.decoder_output_dim, ddc_r, attn_type,
attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, forward_attn_mask, location_attn, attn_K,
separate_stopnet)
self.decoder_in_features,
self.decoder_output_dim,
ddc_r,
attn_type,
attn_win,
attn_norm,
prenet_type,
prenet_dropout,
forward_attn,
trans_agent,
forward_attn_mask,
location_attn,
attn_K,
separate_stopnet,
)

@staticmethod
def shape_outputs(mel_outputs, mel_outputs_postnet, alignments):

@ -148,9 +197,9 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
if self.gst:
# B x gst_dim
encoder_outputs = self.compute_gst(encoder_outputs,
mel_specs,
speaker_embeddings if self.gst_use_speaker_embedding else None)
encoder_outputs = self.compute_gst(
encoder_outputs, mel_specs, speaker_embeddings if self.gst_use_speaker_embedding else None
)
if self.num_speakers > 1:
if not self.embeddings_per_sample:
# B x 1 x speaker_embed_dim

@ -163,8 +212,7 @@ class Tacotron2(TacotronAbstract):
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)

# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
decoder_outputs, alignments, stop_tokens = self.decoder(
encoder_outputs, mel_specs, input_mask)
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
# sequence masking
if mel_lengths is not None:
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)

@ -175,14 +223,29 @@ class Tacotron2(TacotronAbstract):
if output_mask is not None:
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
decoder_outputs, postnet_outputs, alignments)
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
if self.bidirectional_decoder:
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
return (
decoder_outputs,
postnet_outputs,
alignments,
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
if self.double_decoder_consistency:
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(mel_specs, encoder_outputs, alignments, input_mask)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens, decoder_outputs_backward, alignments_backward
|
||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||
mel_specs, encoder_outputs, alignments, input_mask
|
||||
)
|
||||
return (
|
||||
decoder_outputs,
|
||||
postnet_outputs,
|
||||
alignments,
|
||||
stop_tokens,
|
||||
decoder_outputs_backward,
|
||||
alignments_backward,
|
||||
)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
@torch.no_grad()
|
||||
|
@ -192,20 +255,18 @@ class Tacotron2(TacotronAbstract):
|
|||
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs,
|
||||
style_mel,
|
||||
speaker_embeddings if self.gst_use_speaker_embedding else None)
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||
encoder_outputs)
|
||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
||||
postnet_outputs = self.postnet(decoder_outputs)
|
||||
postnet_outputs = decoder_outputs + postnet_outputs
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||
decoder_outputs, postnet_outputs, alignments)
|
||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
||||
|
||||
def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
||||
|
@ -217,19 +278,17 @@ class Tacotron2(TacotronAbstract):
|
|||
|
||||
if self.gst:
|
||||
# B x gst_dim
|
||||
encoder_outputs = self.compute_gst(encoder_outputs,
|
||||
style_mel,
|
||||
speaker_embeddings if self.gst_use_speaker_embedding else None)
|
||||
encoder_outputs = self.compute_gst(
|
||||
encoder_outputs, style_mel, speaker_embeddings if self.gst_use_speaker_embedding else None
|
||||
)
|
||||
|
||||
if self.num_speakers > 1:
|
||||
if not self.embeddings_per_sample:
|
||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
||||
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(
|
||||
encoder_outputs)
|
||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)
|
||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(
|
||||
mel_outputs, mel_outputs_postnet, alignments)
|
||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments)
|
||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
||||
|
|
|
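The hunks above are mechanical black reformatting; the constructor and forward contracts of the PyTorch Tacotron2 are unchanged. As a sanity-level sketch (the argument values are illustrative, not project defaults beyond the 80-bin mels visible above, and the forward signature is assumed), the refactored model can be smoke-tested like this:

import torch
from TTS.tts.models.tacotron2 import Tacotron2

# Illustrative sizes: batch of 2, 50 characters, 120 mel frames (divisible by r=2).
model = Tacotron2(num_chars=100, num_speakers=1, r=2)
text = torch.randint(0, 100, (2, 50))
text_lengths = torch.tensor([50, 45])
mel_specs = torch.rand(2, 120, 80)
mel_lengths = torch.tensor([120, 110])
outputs = model(text, text_lengths, mel_specs, mel_lengths)  # assumed forward signature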
@@ -8,34 +8,36 @@ from TTS.tts.utils.generic_utils import sequence_mask


class TacotronAbstract(ABC, nn.Module):
    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 attn_K=5,
                 separate_stopnet=True,
                 bidirectional_decoder=False,
                 double_decoder_consistency=False,
                 ddc_r=None,
                 encoder_in_features=512,
                 decoder_in_features=512,
                 speaker_embedding_dim=None,
                 gst=False,
                 gst_embedding_dim=512,
                 gst_num_heads=4,
                 gst_style_tokens=10,
                 gst_use_speaker_embedding=False):
    def __init__(
        self,
        num_chars,
        num_speakers,
        r,
        postnet_output_dim=80,
        decoder_output_dim=80,
        attn_type="original",
        attn_win=False,
        attn_norm="softmax",
        prenet_type="original",
        prenet_dropout=True,
        forward_attn=False,
        trans_agent=False,
        forward_attn_mask=False,
        location_attn=True,
        attn_K=5,
        separate_stopnet=True,
        bidirectional_decoder=False,
        double_decoder_consistency=False,
        ddc_r=None,
        encoder_in_features=512,
        decoder_in_features=512,
        speaker_embedding_dim=None,
        gst=False,
        gst_embedding_dim=512,
        gst_num_heads=4,
        gst_style_tokens=10,
        gst_use_speaker_embedding=False,
    ):
        """ Abstract Tacotron class """
        super().__init__()
        self.num_chars = num_chars

@@ -82,7 +84,7 @@ class TacotronAbstract(ABC, nn.Module):

        # global style token
        if self.gst:
            self.decoder_in_features += gst_embedding_dim # add gst embedding dim
            self.decoder_in_features += gst_embedding_dim  # add gst embedding dim
            self.gst_layer = None

        # model states

@@ -121,10 +123,12 @@ class TacotronAbstract(ABC, nn.Module):
    def inference(self):
        pass

    def load_checkpoint(self, config, checkpoint_path, eval=False):  # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.load_state_dict(state['model'])
        self.decoder.set_r(state['r'])
    def load_checkpoint(
        self, config, checkpoint_path, eval=False
    ):  # pylint: disable=unused-argument, redefined-builtin
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
        self.load_state_dict(state["model"])
        self.decoder.set_r(state["r"])
        if eval:
            self.eval()
            assert not self.training
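Since load_checkpoint above reads "model" and "r" out of a torch checkpoint dict and optionally flips the module into eval mode, a typical call site looks like the sketch below (the checkpoint path is a placeholder and the constructor arguments are illustrative):

# Hypothetical usage; `config` is unused by this implementation (see the pylint
# disable above) but kept for API symmetry.
model = Tacotron2(num_chars=100, num_speakers=1, r=2)  # illustrative arguments
model.load_checkpoint(config, "checkpoint_200000.pth.tar", eval=True)
assert not model.training  # eval=True leaves the module in inference mode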
@@ -149,25 +153,24 @@ class TacotronAbstract(ABC, nn.Module):
    def _backward_pass(self, mel_specs, encoder_outputs, mask):
        """ Run backwards decoder """
        decoder_outputs_b, alignments_b, _ = self.decoder_backward(
            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask)
            encoder_outputs, torch.flip(mel_specs, dims=(1,)), mask
        )
        decoder_outputs_b = decoder_outputs_b.transpose(1, 2).contiguous()
        return decoder_outputs_b, alignments_b

    def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments,
                             input_mask):
    def _coarse_decoder_pass(self, mel_specs, encoder_outputs, alignments, input_mask):
        """ Double Decoder Consistency """
        T = mel_specs.shape[1]
        if T % self.coarse_decoder.r > 0:
            padding_size = self.coarse_decoder.r - (T % self.coarse_decoder.r)
            mel_specs = torch.nn.functional.pad(mel_specs,
                                                (0, 0, 0, padding_size, 0, 0))
            mel_specs = torch.nn.functional.pad(mel_specs, (0, 0, 0, padding_size, 0, 0))
        decoder_outputs_backward, alignments_backward, _ = self.coarse_decoder(
            encoder_outputs.detach(), mel_specs, input_mask)
            encoder_outputs.detach(), mel_specs, input_mask
        )
        # scale_factor = self.decoder.r_init / self.decoder.r
        alignments_backward = torch.nn.functional.interpolate(
            alignments_backward.transpose(1, 2),
            size=alignments.shape[1],
            mode='nearest').transpose(1, 2)
            alignments_backward.transpose(1, 2), size=alignments.shape[1], mode="nearest"
        ).transpose(1, 2)
        decoder_outputs_backward = decoder_outputs_backward.transpose(1, 2)
        decoder_outputs_backward = decoder_outputs_backward[:, :T, :]
        return decoder_outputs_backward, alignments_backward

@@ -179,20 +182,17 @@ class TacotronAbstract(ABC, nn.Module):
    def compute_speaker_embedding(self, speaker_ids):
        """ Compute speaker embedding vectors """
        if hasattr(self, "speaker_embedding") and speaker_ids is None:
            raise RuntimeError(
                " [!] Model has speaker embedding layer but speaker_id is not provided"
            )
            raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided")
        if hasattr(self, "speaker_embedding") and speaker_ids is not None:
            self.speaker_embeddings = self.speaker_embedding(speaker_ids).unsqueeze(1)
        if hasattr(self, "speaker_project_mel") and speaker_ids is not None:
            self.speaker_embeddings_projected = self.speaker_project_mel(
                self.speaker_embeddings).squeeze(1)
            self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1)

    def compute_gst(self, inputs, style_input, speaker_embedding=None):
        """ Compute global style token """
        device = inputs.device
        if isinstance(style_input, dict):
            query = torch.zeros(1, 1, self.gst_embedding_dim//2).to(device)
            query = torch.zeros(1, 1, self.gst_embedding_dim // 2).to(device)
            if speaker_embedding is not None:
                query = torch.cat([query, speaker_embedding.reshape(1, 1, -1)], dim=-1)

@@ -205,20 +205,18 @@ class TacotronAbstract(ABC, nn.Module):
        elif style_input is None:
            gst_outputs = torch.zeros(1, 1, self.gst_embedding_dim).to(device)
        else:
            gst_outputs = self.gst_layer(style_input, speaker_embedding) # pylint: disable=not-callable
            gst_outputs = self.gst_layer(style_input, speaker_embedding)  # pylint: disable=not-callable
        inputs = self._concat_speaker_embedding(inputs, gst_outputs)
        return inputs

    @staticmethod
    def _add_speaker_embedding(outputs, speaker_embeddings):
        speaker_embeddings_ = speaker_embeddings.expand(
            outputs.size(0), outputs.size(1), -1)
        speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
        outputs = outputs + speaker_embeddings_
        return outputs

    @staticmethod
    def _concat_speaker_embedding(outputs, speaker_embeddings):
        speaker_embeddings_ = speaker_embeddings.expand(
            outputs.size(0), outputs.size(1), -1)
        speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
        outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
        return outputs

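Both static helpers above first broadcast one embedding vector per utterance across the time axis; _add_speaker_embedding then sums, while _concat_speaker_embedding widens the channel dimension. A shape-level sketch with illustrative sizes:

import torch

B, T, D, D_spk = 2, 7, 512, 64  # illustrative sizes
outputs = torch.rand(B, T, D)
speaker_embeddings = torch.rand(B, 1, D_spk)  # one vector per utterance

expanded = speaker_embeddings.expand(B, T, -1)   # B x T x D_spk view, no copy
concat = torch.cat([outputs, expanded], dim=-1)  # B x T x (D + D_spk)
print(concat.shape)  # torch.Size([2, 7, 576])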
@@ -1,16 +1,18 @@
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops

# from tensorflow_addons.seq2seq import BahdanauAttention

# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg


class Linear(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        super().__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
        self.activation = keras.layers.ReLU()

    def call(self, x):

@@ -23,9 +25,11 @@ class Linear(keras.layers.Layer):

class LinearBN(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(LinearBN, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
        super().__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name="linear_layer")
        self.batch_normalization = keras.layers.BatchNormalization(
            axis=-1, momentum=0.90, epsilon=1e-5, name="batch_normalization"
        )
        self.activation = keras.layers.ReLU()

    def call(self, x, training=None):

@@ -39,22 +43,21 @@ class LinearBN(keras.layers.Layer):


class Prenet(keras.layers.Layer):
    def __init__(self,
                 prenet_type,
                 prenet_dropout,
                 units,
                 bias,
                 **kwargs):
        super(Prenet, self).__init__(**kwargs)
    def __init__(self, prenet_type, prenet_dropout, units, bias, **kwargs):
        super().__init__(**kwargs)
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.linear_layers = []
        if prenet_type == "bn":
            self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
            self.linear_layers += [
                LinearBN(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
            ]
        elif prenet_type == "original":
            self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
            self.linear_layers += [
                Linear(unit, use_bias=bias, name=f"linear_layer_{idx}") for idx, unit in enumerate(units)
            ]
        else:
            raise RuntimeError(' [!] Unknown prenet type.')
            raise RuntimeError(" [!] Unknown prenet type.")
        if prenet_dropout:
            self.dropout = keras.layers.Dropout(rate=0.5)

@@ -80,11 +83,22 @@ def _sigmoid_norm(score):
class Attention(keras.layers.Layer):
    """TODO: implement forward_attention
    TODO: location sensitive attention
    TODO: implement attention windowing """
    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                 use_trans_agent, use_forward_attn_mask, **kwargs):
        super(Attention, self).__init__(**kwargs)
    TODO: implement attention windowing"""

    def __init__(
        self,
        attn_dim,
        use_loc_attn,
        loc_attn_n_filters,
        loc_attn_kernel_size,
        use_windowing,
        norm,
        use_forward_attn,
        use_trans_agent,
        use_forward_attn_mask,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.use_loc_attn = use_loc_attn
        self.loc_attn_n_filters = loc_attn_n_filters
        self.loc_attn_kernel_size = loc_attn_kernel_size

@@ -93,20 +107,23 @@ class Attention(keras.layers.Layer):
        self.use_forward_attn = use_forward_attn
        self.use_trans_agent = use_trans_agent
        self.use_forward_attn_mask = use_forward_attn_mask
        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
        self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
        self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name="query_layer/linear_layer")
        self.inputs_layer = tf.keras.layers.Dense(
            attn_dim, use_bias=False, name=f"{self.name}/inputs_layer/linear_layer"
        )
        self.v = tf.keras.layers.Dense(1, use_bias=True, name="v/linear_layer")
        if use_loc_attn:
            self.location_conv1d = keras.layers.Conv1D(
                filters=loc_attn_n_filters,
                kernel_size=loc_attn_kernel_size,
                padding='same',
                padding="same",
                use_bias=False,
                name='location_layer/location_conv1d')
            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
        if norm == 'softmax':
                name="location_layer/location_conv1d",
            )
            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name="location_layer/location_dense")
        if norm == "softmax":
            self.norm_func = tf.nn.softmax
        elif norm == 'sigmoid':
        elif norm == "sigmoid":
            self.norm_func = _sigmoid_norm
        else:
            raise ValueError("Unknown value for attention norm type")

@@ -118,30 +135,25 @@ class Attention(keras.layers.Layer):
        attention_old = tf.zeros([batch_size, value_length])
        states = [attention_cum, attention_old]
        if self.use_forward_attn:
            alpha = tf.concat([
                tf.ones([batch_size, 1]),
                tf.zeros([batch_size, value_length])[:, :-1] + 1e-7
            ], 1)
            alpha = tf.concat([tf.ones([batch_size, 1]), tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], 1)
            states.append(alpha)
        return tuple(states)

    def process_values(self, values):
        """ cache values for decoder iterations """
        #pylint: disable=attribute-defined-outside-init
        # pylint: disable=attribute-defined-outside-init
        self.processed_values = self.inputs_layer(values)
        self.values = values

    def get_loc_attn(self, query, states):
        """ compute location attention, query layer and
        """compute location attention, query layer and
        unnorm. attention weights"""
        attention_cum, attention_old = states[:2]
        attn_cat = tf.stack([attention_old, attention_cum], axis=2)

        processed_query = self.query_layer(tf.expand_dims(query, 1))
        processed_attn = self.location_dense(self.location_conv1d(attn_cat))
        score = self.v(
            tf.nn.tanh(self.processed_values + processed_query +
                       processed_attn))
        score = self.v(tf.nn.tanh(self.processed_values + processed_query + processed_attn))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

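get_loc_attn above is plain additive (Bahdanau-style) attention with an extra location term: the previous and cumulative attention weights are stacked, convolved, and projected into the same space as the query and cached values before the tanh-and-project scoring. A shape-level sketch with illustrative sizes:

import tensorflow as tf

B, T, attn_dim = 2, 11, 128  # illustrative sizes
processed_values = tf.random.normal([B, T, attn_dim])  # cached by process_values()
processed_query = tf.random.normal([B, 1, attn_dim])   # query expanded to broadcast over T
processed_attn = tf.random.normal([B, T, attn_dim])    # location features after Conv1D + Dense
v = tf.keras.layers.Dense(1, use_bias=True)
score = tf.squeeze(v(tf.nn.tanh(processed_values + processed_query + processed_attn)), axis=2)
print(score.shape)  # (2, 11) -- one unnormalized weight per memory step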
@@ -152,14 +164,14 @@ class Attention(keras.layers.Layer):
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def apply_score_masking(self, score, mask): #pylint: disable=no-self-use
    def apply_score_masking(self, score, mask):  # pylint: disable=no-self-use
        """ ignore sequence paddings """
        padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
        # Bias so padding positions do not contribute to attention distribution.
        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        score -= 1.0e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        return score

    def apply_forward_attention(self, alignment, alpha): #pylint: disable=no-self-use
    def apply_forward_attention(self, alignment, alpha):  # pylint: disable=no-self-use
        # forward attention
        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)), constant_values=0.0)
        # compute transition potentials

@@ -206,7 +218,9 @@ class Attention(keras.layers.Layer):
        states = self.update_states(states, scores_norm, attn_weights, new_alpha)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
        context_vector = tf.matmul(
            tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False
        )
        context_vector = tf.squeeze(context_vector, axis=1)
        return context_vector, attn_weights, states

@@ -230,8 +244,7 @@ class Attention(keras.layers.Layer):
#                  location_attention_filters=32,
#                  location_attention_kernel_size=31):

#         super(LocationSensitiveAttention,
#               self).__init__(units=units,
#         super( self).__init__(units=units,
#                        memory=memory,
#                        memory_sequence_length=memory_sequence_length,
#                        normalize=normalize,

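apply_score_masking above uses the standard additive-mask trick: subtracting 1e9 at padded positions drives those logits toward negative infinity, so a following softmax assigns them effectively zero weight. A minimal standalone illustration:

import tensorflow as tf

scores = tf.constant([[2.0, 1.0, 0.5, 0.1]])
mask = tf.constant([[True, True, False, False]])  # last two memory steps are padding
masked = scores - 1.0e9 * tf.cast(tf.math.logical_not(mask), tf.float32)
print(tf.nn.softmax(masked).numpy())  # padded positions receive ~0 attention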
@@ -5,15 +5,17 @@ from TTS.tts.tf.layers.tacotron.common_layers import Prenet, Attention

# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg
class ConvBNBlock(keras.layers.Layer):
    def __init__(self, filters, kernel_size, activation, **kwargs):
        super(ConvBNBlock, self).__init__(**kwargs)
        self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding='same', name='convolution1d')
        self.batch_normalization = keras.layers.BatchNormalization(axis=2, momentum=0.90, epsilon=1e-5, name='batch_normalization')
        self.dropout = keras.layers.Dropout(rate=0.5, name='dropout')
        self.activation = keras.layers.Activation(activation, name='activation')
        super().__init__(**kwargs)
        self.convolution1d = keras.layers.Conv1D(filters, kernel_size, padding="same", name="convolution1d")
        self.batch_normalization = keras.layers.BatchNormalization(
            axis=2, momentum=0.90, epsilon=1e-5, name="batch_normalization"
        )
        self.dropout = keras.layers.Dropout(rate=0.5, name="dropout")
        self.activation = keras.layers.Activation(activation, name="activation")

    def call(self, x, training=None):
        o = self.convolution1d(x)

@@ -25,12 +27,12 @@ class ConvBNBlock(keras.layers.Layer):

class Postnet(keras.layers.Layer):
    def __init__(self, output_filters, num_convs, **kwargs):
        super(Postnet, self).__init__(**kwargs)
        super().__init__(**kwargs)
        self.convolutions = []
        self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name='convolutions_0'))
        self.convolutions.append(ConvBNBlock(512, 5, "tanh", name="convolutions_0"))
        for idx in range(1, num_convs - 1):
            self.convolutions.append(ConvBNBlock(512, 5, 'tanh', name=f'convolutions_{idx}'))
        self.convolutions.append(ConvBNBlock(output_filters, 5, 'linear', name=f'convolutions_{idx+1}'))
            self.convolutions.append(ConvBNBlock(512, 5, "tanh", name=f"convolutions_{idx}"))
        self.convolutions.append(ConvBNBlock(output_filters, 5, "linear", name=f"convolutions_{idx+1}"))

    def call(self, x, training=None):
        o = x

@@ -41,11 +43,13 @@ class Postnet(keras.layers.Layer):

class Encoder(keras.layers.Layer):
    def __init__(self, output_input_dim, **kwargs):
        super(Encoder, self).__init__(**kwargs)
        super().__init__(**kwargs)
        self.convolutions = []
        for idx in range(3):
            self.convolutions.append(ConvBNBlock(output_input_dim, 5, 'relu', name=f'convolutions_{idx}'))
        self.lstm = keras.layers.Bidirectional(keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name='lstm')
            self.convolutions.append(ConvBNBlock(output_input_dim, 5, "relu", name=f"convolutions_{idx}"))
        self.lstm = keras.layers.Bidirectional(
            keras.layers.LSTM(output_input_dim // 2, return_sequences=True, use_bias=True), name="lstm"
        )

    def call(self, x, training=None):
        o = x

@@ -56,11 +60,27 @@ class Encoder(keras.layers.Layer):


class Decoder(keras.layers.Layer):
    #pylint: disable=unused-argument
    def __init__(self, frame_dim, r, attn_type, use_attn_win, attn_norm, prenet_type,
                 prenet_dropout, use_forward_attn, use_trans_agent, use_forward_attn_mask,
                 use_location_attn, attn_K, separate_stopnet, speaker_emb_dim, enable_tflite, **kwargs):
        super(Decoder, self).__init__(**kwargs)
    # pylint: disable=unused-argument
    def __init__(
        self,
        frame_dim,
        r,
        attn_type,
        use_attn_win,
        attn_norm,
        prenet_type,
        prenet_dropout,
        use_forward_attn,
        use_trans_agent,
        use_forward_attn_mask,
        use_location_attn,
        attn_K,
        separate_stopnet,
        speaker_emb_dim,
        enable_tflite,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.frame_dim = frame_dim
        self.r_init = tf.constant(r, dtype=tf.int32)
        self.r = tf.constant(r, dtype=tf.int32)

@@ -80,30 +100,31 @@ class Decoder(keras.layers.Layer):
        self.p_attention_dropout = 0.1
        self.p_decoder_dropout = 0.1

        self.prenet = Prenet(prenet_type,
                             prenet_dropout,
                             [self.prenet_dim, self.prenet_dim],
                             bias=False,
                             name='prenet')
        self.attention_rnn = keras.layers.LSTMCell(self.query_dim, use_bias=True, name='attention_rnn', )
        self.prenet = Prenet(prenet_type, prenet_dropout, [self.prenet_dim, self.prenet_dim], bias=False, name="prenet")
        self.attention_rnn = keras.layers.LSTMCell(
            self.query_dim,
            use_bias=True,
            name="attention_rnn",
        )
        self.attention_rnn_dropout = keras.layers.Dropout(0.5)

        # TODO: implement other attn options
        self.attention = Attention(attn_dim=self.attn_dim,
                                   use_loc_attn=True,
                                   loc_attn_n_filters=32,
                                   loc_attn_kernel_size=31,
                                   use_windowing=False,
                                   norm=attn_norm,
                                   use_forward_attn=use_forward_attn,
                                   use_trans_agent=use_trans_agent,
                                   use_forward_attn_mask=use_forward_attn_mask,
                                   name='attention')
        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name='decoder_rnn')
        self.attention = Attention(
            attn_dim=self.attn_dim,
            use_loc_attn=True,
            loc_attn_n_filters=32,
            loc_attn_kernel_size=31,
            use_windowing=False,
            norm=attn_norm,
            use_forward_attn=use_forward_attn,
            use_trans_agent=use_trans_agent,
            use_forward_attn_mask=use_forward_attn_mask,
            name="attention",
        )
        self.decoder_rnn = keras.layers.LSTMCell(self.decoder_rnn_dim, use_bias=True, name="decoder_rnn")
        self.decoder_rnn_dropout = keras.layers.Dropout(0.5)
        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name='linear_projection/linear_layer')
        self.stopnet = keras.layers.Dense(1, name='stopnet/linear_layer')

        self.linear_projection = keras.layers.Dense(self.frame_dim * r, name="linear_projection/linear_layer")
        self.stopnet = keras.layers.Dense(1, name="stopnet/linear_layer")

    def set_max_decoder_steps(self, new_max_steps):
        self.max_decoder_steps = tf.constant(new_max_steps, dtype=tf.int32)

@@ -120,25 +141,31 @@ class Decoder(keras.layers.Layer):
        attention_states = self.attention.init_states(batch_size, memory_length)
        return zero_frame, zero_context, attention_rnn_state, decoder_rnn_state, attention_states

    def step(self, prenet_next, states,
             memory_seq_length=None, training=None):
    def step(self, prenet_next, states, memory_seq_length=None, training=None):
        _, context_next, attention_rnn_state, decoder_rnn_state, attention_states = states
        attention_rnn_input = tf.concat([prenet_next, context_next], -1)
        attention_rnn_output, attention_rnn_state = \
            self.attention_rnn(attention_rnn_input,
                               attention_rnn_state, training=training)
        attention_rnn_output, attention_rnn_state = self.attention_rnn(
            attention_rnn_input, attention_rnn_state, training=training
        )
        attention_rnn_output = self.attention_rnn_dropout(attention_rnn_output, training=training)
        context, attention, attention_states = self.attention(attention_rnn_output, attention_states, training=training)
        decoder_rnn_input = tf.concat([attention_rnn_output, context], -1)
        decoder_rnn_output, decoder_rnn_state = \
            self.decoder_rnn(decoder_rnn_input, decoder_rnn_state, training=training)
        decoder_rnn_output, decoder_rnn_state = self.decoder_rnn(
            decoder_rnn_input, decoder_rnn_state, training=training
        )
        decoder_rnn_output = self.decoder_rnn_dropout(decoder_rnn_output, training=training)
        linear_projection_input = tf.concat([decoder_rnn_output, context], -1)
        output_frame = self.linear_projection(linear_projection_input, training=training)
        stopnet_input = tf.concat([decoder_rnn_output, output_frame], -1)
        stopnet_output = self.stopnet(stopnet_input, training=training)
        output_frame = output_frame[:, :self.r * self.frame_dim]
        states = (output_frame[:, self.frame_dim * (self.r - 1):], context, attention_rnn_state, decoder_rnn_state, attention_states)
        output_frame = output_frame[:, : self.r * self.frame_dim]
        states = (
            output_frame[:, self.frame_dim * (self.r - 1) :],
            context,
            attention_rnn_state,
            decoder_rnn_state,
            attention_states,
        )
        return output_frame, stopnet_output, states, attention

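In step above, the linear projection emits frame_dim * r_init values; the first r * frame_dim of them are kept for the current reduction factor, and only the last kept frame is fed back as the next decoder input. With illustrative numbers:

import numpy as np

frame_dim, r = 80, 2                              # illustrative: 80 mel bins, reduction factor 2
projection = np.arange(7 * frame_dim)[None, :]    # pretend r_init == 7
output_frame = projection[:, : r * frame_dim]     # keep the first r frames (160 values)
feedback = output_frame[:, frame_dim * (r - 1):]  # the last kept frame, 80 values
print(output_frame.shape, feedback.shape)         # (1, 160) (1, 80)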
    def decode(self, memory, states, frames, memory_seq_length=None):

@@ -157,21 +184,20 @@ class Decoder(keras.layers.Layer):

        def _body(step, memory, prenet_output, states, outputs, stop_tokens, attentions):
            prenet_next = prenet_output[:, step]
            output, stop_token, states, attention = self.step(prenet_next,
                                                              states,
                                                              memory_seq_length)
            output, stop_token, states, attention = self.step(prenet_next, states, memory_seq_length)
            outputs = outputs.write(step, output)
            attentions = attentions.write(step, attention)
            stop_tokens = stop_tokens.write(step, stop_token)
            return step + 1, memory, prenet_output, states, outputs, stop_tokens, attentions
        _, memory, _, states, outputs, stop_tokens, attentions = \
            tf.while_loop(lambda *arg: True,
                          _body,
                          loop_vars=(step_count, memory, prenet_output,
                                     states, outputs, stop_tokens, attentions),
                          parallel_iterations=32,
                          swap_memory=True,
                          maximum_iterations=num_iter)

        _, memory, _, states, outputs, stop_tokens, attentions = tf.while_loop(
            lambda *arg: True,
            _body,
            loop_vars=(step_count, memory, prenet_output, states, outputs, stop_tokens, attentions),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=num_iter,
        )

        outputs = outputs.stack()
        attentions = attentions.stack()

@@ -200,10 +226,7 @@ class Decoder(keras.layers.Layer):
        def _body(step, memory, states, outputs, stop_tokens, attentions, stop_flag):
            frame_next = states[0]
            prenet_next = self.prenet(frame_next, training=False)
            output, stop_token, states, attention = self.step(prenet_next,
                                                              states,
                                                              None,
                                                              training=False)
            output, stop_token, states, attention = self.step(prenet_next, states, None, training=False)
            stop_token = tf.math.sigmoid(stop_token)
            outputs = outputs.write(step, output)
            attentions = attentions.write(step, attention)

@@ -213,14 +236,14 @@ class Decoder(keras.layers.Layer):
            return step + 1, memory, states, outputs, stop_tokens, attentions, stop_flag

        cond = lambda step, m, s, o, st, a, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
        _, memory, states, outputs, stop_tokens, attentions, stop_flag = \
            tf.while_loop(cond,
                          _body,
                          loop_vars=(step_count, memory, states, outputs,
                                     stop_tokens, attentions, stop_flag),
                          parallel_iterations=32,
                          swap_memory=True,
                          maximum_iterations=self.max_decoder_steps)
        _, memory, states, outputs, stop_tokens, attentions, stop_flag = tf.while_loop(
            cond,
            _body,
            loop_vars=(step_count, memory, states, outputs, stop_tokens, attentions, stop_flag),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=self.max_decoder_steps,
        )

        outputs = outputs.stack()
        attentions = attentions.stack()

@@ -238,12 +261,13 @@ class Decoder(keras.layers.Layer):
        batch_size is 1"""
        # init states
        # dynamic_shape is not supported in TFLite
        outputs = tf.TensorArray(dtype=tf.float32,
                                 size=self.max_decoder_steps,
                                 element_shape=tf.TensorShape(
                                     [self.output_dim]),
                                 clear_after_read=False,
                                 dynamic_size=False)
        outputs = tf.TensorArray(
            dtype=tf.float32,
            size=self.max_decoder_steps,
            element_shape=tf.TensorShape([self.output_dim]),
            clear_after_read=False,
            dynamic_size=False,
        )
        # stop_flags = tf.TensorArray(dtype=tf.bool,
        #                             size=self.max_decoder_steps,
        #                             element_shape=tf.TensorShape(

@@ -263,10 +287,7 @@ class Decoder(keras.layers.Layer):
        def _body(step, memory, states, outputs, stop_flag):
            frame_next = states[0]
            prenet_next = self.prenet(frame_next, training=False)
            output, stop_token, states, _ = self.step(prenet_next,
                                                      states,
                                                      None,
                                                      training=False)
            output, stop_token, states, _ = self.step(prenet_next, states, None, training=False)
            stop_token = tf.math.sigmoid(stop_token)
            stop_flag = tf.greater(stop_token, self.stop_thresh)
            stop_flag = tf.reduce_all(stop_flag)

@@ -276,24 +297,22 @@ class Decoder(keras.layers.Layer):
            return step + 1, memory, states, outputs, stop_flag

        cond = lambda step, m, s, o, stop_flag: tf.equal(stop_flag, tf.constant(False, dtype=tf.bool))
        step_count, memory, states, outputs, stop_flag = \
            tf.while_loop(cond,
                          _body,
                          loop_vars=(step_count, memory, states, outputs,
                                     stop_flag),
                          parallel_iterations=32,
                          swap_memory=True,
                          maximum_iterations=self.max_decoder_steps)

        step_count, memory, states, outputs, stop_flag = tf.while_loop(
            cond,
            _body,
            loop_vars=(step_count, memory, states, outputs, stop_flag),
            parallel_iterations=32,
            swap_memory=True,
            maximum_iterations=self.max_decoder_steps,
        )

        outputs = outputs.stack()
        outputs = tf.gather(outputs, tf.range(step_count)) # pylint: disable=no-value-for-parameter
        outputs = tf.gather(outputs, tf.range(step_count))  # pylint: disable=no-value-for-parameter
        outputs = tf.expand_dims(outputs, axis=[0])
        outputs = tf.transpose(outputs, [1, 0, 2])
        outputs = tf.reshape(outputs, [1, -1, self.frame_dim])
        return outputs, stop_tokens, attentions


    def call(self, memory, states, frames=None, memory_seq_length=None, training=False):
        if training:
            return self.decode(memory, states, frames, memory_seq_length)

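All three decoding loops above share one skeleton: TensorArray accumulators, a _body that emits one frame per iteration, and tf.while_loop bounded by maximum_iterations. Stripped to its bones (the per-step computation is a placeholder):

import tensorflow as tf

max_steps, frame_dim = 4, 3  # placeholder sizes

outputs = tf.TensorArray(dtype=tf.float32, size=max_steps, dynamic_size=False)

def _body(step, outputs):
    frame = tf.fill([frame_dim], tf.cast(step, tf.float32))  # stand-in for one decoder step
    return step + 1, outputs.write(step, frame)

_, outputs = tf.while_loop(
    lambda step, outputs: True,  # run until maximum_iterations, like the training path above
    _body,
    loop_vars=(tf.constant(0), outputs),
    parallel_iterations=32,
    swap_memory=True,
    maximum_iterations=max_steps,
)
print(outputs.stack().shape)  # (4, 3)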
@@ -5,28 +5,30 @@ from TTS.tts.tf.layers.tacotron.tacotron2 import Encoder, Decoder, Postnet
from TTS.tts.tf.utils.tf_utils import shape_list


#pylint: disable=too-many-ancestors, abstract-method
# pylint: disable=too-many-ancestors, abstract-method
class Tacotron2(keras.models.Model):
    def __init__(self,
                 num_chars,
                 num_speakers,
                 r,
                 postnet_output_dim=80,
                 decoder_output_dim=80,
                 attn_type='original',
                 attn_win=False,
                 attn_norm="softmax",
                 attn_K=4,
                 prenet_type="original",
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True,
                 bidirectional_decoder=False,
                 enable_tflite=False):
        super(Tacotron2, self).__init__()
    def __init__(
        self,
        num_chars,
        num_speakers,
        r,
        postnet_output_dim=80,
        decoder_output_dim=80,
        attn_type="original",
        attn_win=False,
        attn_norm="softmax",
        attn_K=4,
        prenet_type="original",
        prenet_dropout=True,
        forward_attn=False,
        trans_agent=False,
        forward_attn_mask=False,
        location_attn=True,
        separate_stopnet=True,
        bidirectional_decoder=False,
        enable_tflite=False,
    ):
        super().__init__()
        self.r = r
        self.decoder_output_dim = decoder_output_dim
        self.postnet_output_dim = postnet_output_dim

@@ -35,26 +37,28 @@ class Tacotron2(keras.models.Model):
        self.speaker_embed_dim = 256
        self.enable_tflite = enable_tflite

        self.embedding = keras.layers.Embedding(num_chars, 512, name='embedding')
        self.encoder = Encoder(512, name='encoder')
        self.embedding = keras.layers.Embedding(num_chars, 512, name="embedding")
        self.encoder = Encoder(512, name="encoder")
        # TODO: most of the decoder args have no use at the moment
        self.decoder = Decoder(decoder_output_dim,
                               r,
                               attn_type=attn_type,
                               use_attn_win=attn_win,
                               attn_norm=attn_norm,
                               prenet_type=prenet_type,
                               prenet_dropout=prenet_dropout,
                               use_forward_attn=forward_attn,
                               use_trans_agent=trans_agent,
                               use_forward_attn_mask=forward_attn_mask,
                               use_location_attn=location_attn,
                               attn_K=attn_K,
                               separate_stopnet=separate_stopnet,
                               speaker_emb_dim=self.speaker_embed_dim,
                               name='decoder',
                               enable_tflite=enable_tflite)
        self.postnet = Postnet(postnet_output_dim, 5, name='postnet')
        self.decoder = Decoder(
            decoder_output_dim,
            r,
            attn_type=attn_type,
            use_attn_win=attn_win,
            attn_norm=attn_norm,
            prenet_type=prenet_type,
            prenet_dropout=prenet_dropout,
            use_forward_attn=forward_attn,
            use_trans_agent=trans_agent,
            use_forward_attn_mask=forward_attn_mask,
            use_location_attn=location_attn,
            attn_K=attn_K,
            separate_stopnet=separate_stopnet,
            speaker_emb_dim=self.speaker_embed_dim,
            name="decoder",
            enable_tflite=enable_tflite,
        )
        self.postnet = Postnet(postnet_output_dim, 5, name="postnet")

    @tf.function(experimental_relax_shapes=True)
    def call(self, characters, text_lengths=None, frames=None, training=None):

@@ -62,14 +66,16 @@ class Tacotron2(keras.models.Model):
            return self.training(characters, text_lengths, frames)
        if not training:
            return self.inference(characters)
        raise RuntimeError(' [!] Set model training mode True or False')
        raise RuntimeError(" [!] Set model training mode True or False")

    def training(self, characters, text_lengths, frames):
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=True)
        encoder_output = self.encoder(embedding_vectors, training=True)
        decoder_states = self.decoder.build_decoder_initial_states(B, 512, T)
        decoder_frames, stop_tokens, attentions = self.decoder(encoder_output, decoder_states, frames, text_lengths, training=True)
        decoder_frames, stop_tokens, attentions = self.decoder(
            encoder_output, decoder_states, frames, text_lengths, training=True
        )
        postnet_frames = self.postnet(decoder_frames, training=True)
        output_frames = decoder_frames + postnet_frames
        return decoder_frames, output_frames, attentions, stop_tokens

@@ -89,7 +95,8 @@ class Tacotron2(keras.models.Model):
        experimental_relax_shapes=True,
        input_signature=[
            tf.TensorSpec([1, None], dtype=tf.int32),
        ],)
        ],
    )
    def inference_tflite(self, characters):
        B, T = shape_list(characters)
        embedding_vectors = self.embedding(characters, training=False)

@@ -101,7 +108,9 @@ class Tacotron2(keras.models.Model):
        print(output_frames.shape)
        return decoder_frames, output_frames, attentions, stop_tokens

    def build_inference(self, ):
    def build_inference(
        self,
    ):
        # TODO: issue https://github.com/PyCQA/pylint/issues/3613
        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32) #pylint: disable=unexpected-keyword-arg
        input_ids = tf.random.uniform(shape=[1, 4], maxval=10, dtype=tf.int32)  # pylint: disable=unexpected-keyword-arg
        self(input_ids)

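The Keras call above requires an explicit training flag: True routes to the class's own training method, False routes to inference, and anything else raises. build_inference exists to run one dummy batch so the variables are created before a checkpoint is loaded. A hedged driver sketch (constructor arguments and the input tensor are placeholder values):

import tensorflow as tf
from TTS.tts.tf.models.tacotron2 import Tacotron2

# Placeholder sizes; real values come from the training config.
model = Tacotron2(num_chars=100, num_speakers=1, r=2)
model.build_inference()  # traces the model once with dummy ids so weights exist
characters = tf.random.uniform([1, 30], maxval=100, dtype=tf.int32)
decoder_frames, output_frames, attentions, stop_tokens = model(characters, training=False)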
@@ -2,8 +2,9 @@ import numpy as np
import tensorflow as tf

# NOTE: linter has a problem with the current TF release
#pylint: disable=no-value-for-parameter
#pylint: disable=unexpected-keyword-arg
# pylint: disable=no-value-for-parameter
# pylint: disable=unexpected-keyword-arg


def tf_create_dummy_inputs():
    """ Create dummy inputs for TF Tacotron2 model """

@@ -13,11 +14,11 @@ def tf_create_dummy_inputs():
    pad = 1
    n_chars = 24
    input_ids = tf.random.uniform([batch_size, max_input_length + pad], maxval=n_chars, dtype=tf.int32)
    input_lengths = np.random.randint(0, high=max_input_length+1 + pad, size=[batch_size])
    input_lengths = np.random.randint(0, high=max_input_length + 1 + pad, size=[batch_size])
    input_lengths[-1] = max_input_length
    input_lengths = tf.convert_to_tensor(input_lengths, dtype=tf.int32)
    mel_outputs = tf.random.uniform(shape=[batch_size, max_mel_length + pad, 80])
    mel_lengths = np.random.randint(0, high=max_mel_length+1 + pad, size=[batch_size])
    mel_lengths = np.random.randint(0, high=max_mel_length + 1 + pad, size=[batch_size])
    mel_lengths[-1] = max_mel_length
    mel_lengths = tf.convert_to_tensor(mel_lengths, dtype=tf.int32)
    return input_ids, input_lengths, mel_outputs, mel_lengths

@@ -31,14 +32,14 @@ def compare_torch_tf(torch_tensor, tf_tensor):
def convert_tf_name(tf_name):
    """ Convert certain patterns in TF layer names to Torch patterns """
    tf_name_tmp = tf_name
    tf_name_tmp = tf_name_tmp.replace(':0', '')
    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_1/recurrent_kernel', '/weight_hh_l0')
    tf_name_tmp = tf_name_tmp.replace('/forward_lstm/lstm_cell_2/kernel', '/weight_ih_l1')
    tf_name_tmp = tf_name_tmp.replace('/recurrent_kernel', '/weight_hh')
    tf_name_tmp = tf_name_tmp.replace('/kernel', '/weight')
    tf_name_tmp = tf_name_tmp.replace('/gamma', '/weight')
    tf_name_tmp = tf_name_tmp.replace('/beta', '/bias')
    tf_name_tmp = tf_name_tmp.replace('/', '.')
    tf_name_tmp = tf_name_tmp.replace(":0", "")
    tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_1/recurrent_kernel", "/weight_hh_l0")
    tf_name_tmp = tf_name_tmp.replace("/forward_lstm/lstm_cell_2/kernel", "/weight_ih_l1")
    tf_name_tmp = tf_name_tmp.replace("/recurrent_kernel", "/weight_hh")
    tf_name_tmp = tf_name_tmp.replace("/kernel", "/weight")
    tf_name_tmp = tf_name_tmp.replace("/gamma", "/weight")
    tf_name_tmp = tf_name_tmp.replace("/beta", "/bias")
    tf_name_tmp = tf_name_tmp.replace("/", ".")
    return tf_name_tmp

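The replacement order in convert_tf_name matters: the long bidirectional-LSTM patterns are rewritten before the generic "/recurrent_kernel" and "/kernel" rules, and slashes become dots only at the end. Two worked examples:

convert_tf_name("encoder/lstm/forward_lstm/lstm_cell_1/recurrent_kernel:0")
# -> "encoder.lstm.weight_hh_l0"
convert_tf_name("postnet/convolutions_0/batch_normalization/gamma:0")
# -> "postnet.convolutions_0.batch_normalization.weight"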
@@ -47,33 +48,35 @@ def transfer_weights_torch_to_tf(tf_vars, var_map_dict, state_dict):
    print(" > Passing weights from Torch to TF ...")
    for tf_var in tf_vars:
        torch_var_name = var_map_dict[tf_var.name]
        print(f' | > {tf_var.name} <-- {torch_var_name}')
        print(f" | > {tf_var.name} <-- {torch_var_name}")
        # if tuple, it is a bias variable
        if not isinstance(torch_var_name, tuple):
            torch_layer_name = '.'.join(torch_var_name.split('.')[-2:])
            torch_layer_name = ".".join(torch_var_name.split(".")[-2:])
            torch_weight = state_dict[torch_var_name]
        if 'convolution1d/kernel' in tf_var.name or 'conv1d/kernel' in tf_var.name:
        if "convolution1d/kernel" in tf_var.name or "conv1d/kernel" in tf_var.name:
            # out_dim, in_dim, filter -> filter, in_dim, out_dim
            numpy_weight = torch_weight.permute([2, 1, 0]).detach().cpu().numpy()
        elif 'lstm_cell' in tf_var.name and 'kernel' in tf_var.name:
        elif "lstm_cell" in tf_var.name and "kernel" in tf_var.name:
            numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
        # if variable is for bidirectional lstm and it is a bias vector there
        # needs to be pre-defined two matching torch bias vectors
        elif '_lstm/lstm_cell_' in tf_var.name and 'bias' in tf_var.name:
        elif "_lstm/lstm_cell_" in tf_var.name and "bias" in tf_var.name:
            bias_vectors = [value for key, value in state_dict.items() if key in torch_var_name]
            assert len(bias_vectors) == 2
            numpy_weight = bias_vectors[0] + bias_vectors[1]
        elif 'rnn' in tf_var.name and 'kernel' in tf_var.name:
        elif "rnn" in tf_var.name and "kernel" in tf_var.name:
            numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
        elif 'rnn' in tf_var.name and 'bias' in tf_var.name:
        elif "rnn" in tf_var.name and "bias" in tf_var.name:
            bias_vectors = [value for key, value in state_dict.items() if torch_var_name[:-2] in key]
            assert len(bias_vectors) == 2
            numpy_weight = bias_vectors[0] + bias_vectors[1]
        elif 'linear_layer' in torch_layer_name and 'weight' in torch_var_name:
        elif "linear_layer" in torch_layer_name and "weight" in torch_var_name:
            numpy_weight = torch_weight.transpose(0, 1).detach().cpu().numpy()
        else:
            numpy_weight = torch_weight.detach().cpu().numpy()
        assert np.all(tf_var.shape == numpy_weight.shape), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
        assert np.all(
            tf_var.shape == numpy_weight.shape
        ), f" [!] weight shapes do not match: {tf_var.name} vs {torch_var_name} --> {tf_var.shape} vs {numpy_weight.shape}"
        tf.keras.backend.set_value(tf_var, numpy_weight)
    return tf_vars

@@ -7,20 +7,20 @@ import tensorflow as tf

def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    state = {
        'model': model.weights,
        'optimizer': optimizer,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r
        "model": model.weights,
        "optimizer": optimizer,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
    }
    state.update(kwargs)
    pickle.dump(state, open(output_path, 'wb'))
    pickle.dump(state, open(output_path, "wb"))


def load_checkpoint(model, checkpoint_path):
    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    checkpoint = pickle.load(open(checkpoint_path, "rb"))
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name

@@ -32,8 +32,8 @@ def load_checkpoint(model, checkpoint_path):
        chkp_var_value = chkp_var_dict[layer_name]

        tf.keras.backend.set_value(tf_var, chkp_var_value)
    if 'r' in checkpoint.keys():
        model.decoder.set_r(checkpoint['r'])
    if "r" in checkpoint.keys():
        model.decoder.set_r(checkpoint["r"])
    return model


@@ -45,8 +45,7 @@ def sequence_mask(sequence_length, max_len=None):
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    if sequence_length.is_cuda:
        seq_range_expand = seq_range_expand.cuda()
    seq_length_expand = (
        sequence_length.unsqueeze(1).expand_as(seq_range_expand))
    seq_length_expand = sequence_length.unsqueeze(1).expand_as(seq_range_expand)
    # B x T_max
    return seq_range_expand < seq_length_expand

@@ -62,42 +61,42 @@ def count_parameters(model, c):
    try:
        return model.count_params()
    except RuntimeError:
        input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype('int32'))
        input_lengths = np.random.randint(100, 129, (8, ))
        input_dummy = tf.convert_to_tensor(np.random.rand(8, 128).astype("int32"))
        input_lengths = np.random.randint(100, 129, (8,))
        input_lengths[-1] = 128
        input_lengths = tf.convert_to_tensor(input_lengths.astype('int32'))
        mel_spec = np.random.rand(8, 2 * c.r,
                                  c.audio['num_mels']).astype('float32')
        input_lengths = tf.convert_to_tensor(input_lengths.astype("int32"))
        mel_spec = np.random.rand(8, 2 * c.r, c.audio["num_mels"]).astype("float32")
        mel_spec = tf.convert_to_tensor(mel_spec)
        speaker_ids = np.random.randint(
            0, 5, (8, )) if c.use_speaker_embedding else None
        speaker_ids = np.random.randint(0, 5, (8,)) if c.use_speaker_embedding else None
        _ = model(input_dummy, input_lengths, mel_spec, speaker_ids=speaker_ids)
    return model.count_params()


def setup_model(num_chars, num_speakers, c, enable_tflite=False):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.tts.tf.models.' + c.model.lower())
    MyModel = importlib.import_module("TTS.tts.tf.models." + c.model.lower())
    MyModel = getattr(MyModel, c.model)
    if c.model.lower() in "tacotron":
        raise NotImplementedError(' [!] Tacotron model is not ready.')
        raise NotImplementedError(" [!] Tacotron model is not ready.")
    # tacotron2
    model = MyModel(num_chars=num_chars,
                    num_speakers=num_speakers,
                    r=c.r,
                    postnet_output_dim=c.audio['num_mels'],
                    decoder_output_dim=c.audio['num_mels'],
                    attn_type=c.attention_type,
                    attn_win=c.windowing,
                    attn_norm=c.attention_norm,
                    prenet_type=c.prenet_type,
                    prenet_dropout=c.prenet_dropout,
                    forward_attn=c.use_forward_attn,
                    trans_agent=c.transition_agent,
                    forward_attn_mask=c.forward_attn_mask,
                    location_attn=c.location_attn,
                    attn_K=c.attention_heads,
                    separate_stopnet=c.separate_stopnet,
                    bidirectional_decoder=c.bidirectional_decoder,
                    enable_tflite=enable_tflite)
    model = MyModel(
        num_chars=num_chars,
        num_speakers=num_speakers,
        r=c.r,
        postnet_output_dim=c.audio["num_mels"],
        decoder_output_dim=c.audio["num_mels"],
        attn_type=c.attention_type,
        attn_win=c.windowing,
        attn_norm=c.attention_norm,
        prenet_type=c.prenet_type,
        prenet_dropout=c.prenet_dropout,
        forward_attn=c.use_forward_attn,
        trans_agent=c.transition_agent,
        forward_attn_mask=c.forward_attn_mask,
        location_attn=c.location_attn,
        attn_K=c.attention_heads,
        separate_stopnet=c.separate_stopnet,
        bidirectional_decoder=c.bidirectional_decoder,
        enable_tflite=enable_tflite,
    )
    return model

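setup_model above resolves the class by name from TTS.tts.tf.models and forwards config fields, while count_parameters builds the model with a dummy batch when it is not built yet. A usage sketch, assuming `c` is a config loaded with the repo's load_config and carrying the attention/audio fields referenced above:

# Assumes c.model == "Tacotron2" plus the config fields used in setup_model above.
model = setup_model(num_chars=100, num_speakers=1, c=c, enable_tflite=False)
print(count_parameters(model, c))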
@@ -5,20 +5,20 @@ import tensorflow as tf

def save_checkpoint(model, optimizer, current_step, epoch, r, output_path, **kwargs):
    state = {
        'model': model.weights,
        'optimizer': optimizer,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r
        "model": model.weights,
        "optimizer": optimizer,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
    }
    state.update(kwargs)
    pickle.dump(state, open(output_path, 'wb'))
    pickle.dump(state, open(output_path, "wb"))


def load_checkpoint(model, checkpoint_path):
    checkpoint = pickle.load(open(checkpoint_path, 'rb'))
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint['model']}
    checkpoint = pickle.load(open(checkpoint_path, "rb"))
    chkp_var_dict = {var.name: var.numpy() for var in checkpoint["model"]}
    tf_vars = model.weights
    for tf_var in tf_vars:
        layer_name = tf_var.name

@@ -30,8 +30,8 @@ def load_checkpoint(model, checkpoint_path):
        chkp_var_value = chkp_var_dict[layer_name]

        tf.keras.backend.set_value(tf_var, chkp_var_value)
    if 'r' in checkpoint.keys():
        model.decoder.set_r(checkpoint['r'])
    if "r" in checkpoint.keys():
        model.decoder.set_r(checkpoint["r"])
    return model

@@ -1,25 +1,20 @@
import tensorflow as tf


def convert_tacotron2_to_tflite(model,
                                output_path=None,
                                experimental_converter=True):
def convert_tacotron2_to_tflite(model, output_path=None, experimental_converter=True):
    """Convert Tensorflow Tacotron2 model to TFLite. Save a binary file if output_path is
    provided, else return TFLite model."""

    concrete_function = model.inference_tflite.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions(
        [concrete_function])
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_function])
    converter.experimental_new_converter = experimental_converter
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
    ]
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    tflite_model = converter.convert()
    print(f'Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.')
    print(f"Tflite Model size is {len(tflite_model) / (1024.0 * 1024.0)} MBs.")
    if output_path is not None:
# same model binary if outputpath is provided
|
||||
with open(output_path, 'wb') as f:
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(tflite_model)
|
||||
return None
|
||||
return tflite_model
|
||||
|
|
|
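A hedged usage sketch for the converter; the model object and output path are placeholders, and it assumes a Tacotron2 TF model exposing inference_tflite, as in this repo's TF implementation:

    # convert and write the flatbuffer next to the checkpoint
    convert_tacotron2_to_tflite(model, output_path="tacotron2.tflite")
    # or keep the serialized blob in memory
    tflite_blob = convert_tacotron2_to_tflite(model)
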
@@ -1,4 +1,3 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

@@ -31,38 +30,37 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
    # check num first
    nd = str(num)
    if abs(float(nd)) >= 1e48:
        raise ValueError('number out of range')
    if 'e' in nd:
        raise ValueError('scientific notation is not supported')
    c_symbol = '正负点' if simp else '正負點'
        raise ValueError("number out of range")
    if "e" in nd:
        raise ValueError("scientific notation is not supported")
    c_symbol = "正负点" if simp else "正負點"
    if o:  # formal
        twoalt = False
    if big:
        c_basic = '零壹贰叁肆伍陆柒捌玖' if simp else '零壹貳參肆伍陸柒捌玖'
        c_unit1 = '拾佰仟'
        c_twoalt = '贰' if simp else '貳'
        c_basic = "零壹贰叁肆伍陆柒捌玖" if simp else "零壹貳參肆伍陸柒捌玖"
        c_unit1 = "拾佰仟"
        c_twoalt = "贰" if simp else "貳"
    else:
        c_basic = '〇一二三四五六七八九' if o else '零一二三四五六七八九'
        c_unit1 = '十百千'
        c_basic = "〇一二三四五六七八九" if o else "零一二三四五六七八九"
        c_unit1 = "十百千"
        if twoalt:
            c_twoalt = '两' if simp else '兩'
            c_twoalt = "两" if simp else "兩"
        else:
            c_twoalt = '二'
    c_unit2 = '万亿兆京垓秭穰沟涧正载' if simp else '萬億兆京垓秭穰溝澗正載'
    revuniq = lambda l: ''.join(k for k, g in itertools.groupby(reversed(l)))
            c_twoalt = "二"
    c_unit2 = "万亿兆京垓秭穰沟涧正载" if simp else "萬億兆京垓秭穰溝澗正載"
    revuniq = lambda l: "".join(k for k, g in itertools.groupby(reversed(l)))
    nd = str(num)
    result = []
    if nd[0] == '+':
    if nd[0] == "+":
        result.append(c_symbol[0])
    elif nd[0] == '-':
    elif nd[0] == "-":
        result.append(c_symbol[1])
    if '.' in nd:
        integer, remainder = nd.lstrip('+-').split('.')
    if "." in nd:
        integer, remainder = nd.lstrip("+-").split(".")
    else:
        integer, remainder = nd.lstrip('+-'), None
        integer, remainder = nd.lstrip("+-"), None
    if int(integer):
        splitted = [integer[max(i - 4, 0):i]
                    for i in range(len(integer), 0, -4)]
        splitted = [integer[max(i - 4, 0) : i] for i in range(len(integer), 0, -4)]
        intresult = []
        for nu, unit in enumerate(splitted):
            # special cases

@@ -75,17 +73,17 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
            ulist = []
            unit = unit.zfill(4)
            for nc, ch in enumerate(reversed(unit)):
                if ch == '0':
                if ch == "0":
                    if ulist:  # ???0
                        ulist.append(c_basic[0])
                elif nc == 0:
                    ulist.append(c_basic[int(ch)])
                elif nc == 1 and ch == '1' and unit[1] == '0':
                elif nc == 1 and ch == "1" and unit[1] == "0":
                    # special case for tens
                    # edit the 'elif' if you don't like
                    # 十四, 三千零十四, 三千三百一十四
                    ulist.append(c_unit1[0])
                elif nc > 1 and ch == '2':
                elif nc > 1 and ch == "2":
                    ulist.append(c_twoalt + c_unit1[nc - 1])
                else:
                    ulist.append(c_basic[int(ch)] + c_unit1[nc - 1])

@@ -99,10 +97,8 @@ def _num2chinese(num: str, big=False, simp=True, o=False, twoalt=False) -> str:
            result.append(c_basic[0])
    if remainder:
        result.append(c_symbol[2])
        result.append(''.join(c_basic[int(ch)] for ch in remainder))
    return ''.join(result)


        result.append("".join(c_basic[int(ch)] for ch in remainder))
    return "".join(result)


def _number_replace(match) -> str:

@@ -127,5 +123,5 @@ def replace_numbers_to_characters_in_text(text: str) -> str:
    Returns:
        str: output text
    """
    text = re.sub(r'[0-9]+', _number_replace, text)
    text = re.sub(r"[0-9]+", _number_replace, text)
    return text

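A quick sanity check for the number normalizer; per the tens special case above, 14 should render as 十四 rather than 一十四:

    replace_numbers_to_characters_in_text("我有14个苹果")  # expected: 我有十四个苹果
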
@@ -9,9 +9,7 @@ import jieba


def _chinese_character_to_pinyin(text: str) -> List[str]:
    pinyins = pypinyin.pinyin(
        text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True
    )
    pinyins = pypinyin.pinyin(text, style=pypinyin.Style.TONE3, heteronym=False, neutral_tone_with_five=True)
    pinyins_flat_list = [item for sublist in pinyins for item in sublist]
    return pinyins_flat_list

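For reference, a hedged example of what the helper returns; pypinyin's TONE3 style appends the tone digit to each syllable, and neutral tones get a 5:

    _chinese_character_to_pinyin("你好")  # -> ["ni3", "hao3"]
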
@@ -1,4 +1,3 @@
PINYIN_DICT = {
    "a": ["a"],
    "ai": ["ai"],

@@ -4,8 +4,7 @@ import numpy as np
def _pad_data(x, length):
    _pad = 0
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
    return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)


def prepare_data(inputs):

@@ -14,12 +13,9 @@ def prepare_data(inputs):


def _pad_tensor(x, length):
    _pad = 0.
    _pad = 0.0
    assert x.ndim == 2
    x = np.pad(
        x, [[0, 0], [0, length - x.shape[1]]],
        mode='constant',
        constant_values=_pad)
    x = np.pad(x, [[0, 0], [0, length - x.shape[1]]], mode="constant", constant_values=_pad)
    return x


@@ -31,10 +27,9 @@ def prepare_tensor(inputs, out_steps):


def _pad_stop_target(x, length):
    _pad = 0.
    _pad = 0.0
    assert x.ndim == 1
    return np.pad(
        x, (0, length - x.shape[0]), mode='constant', constant_values=_pad)
    return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad)


def prepare_stop_target(inputs, out_steps):

@@ -46,22 +41,18 @@ def prepare_stop_target(inputs, out_steps):


def pad_per_step(inputs, pad_len):
    return np.pad(
        inputs, [[0, 0], [0, 0], [0, pad_len]],
        mode='constant',
        constant_values=0.0)
    return np.pad(inputs, [[0, 0], [0, 0], [0, pad_len]], mode="constant", constant_values=0.0)


# pylint: disable=attribute-defined-outside-init
class StandardScaler():

class StandardScaler:
    def set_stats(self, mean, scale):
        self.mean_ = mean
        self.scale_ = scale

    def reset_stats(self):
        delattr(self, 'mean_')
        delattr(self, 'scale_')
        delattr(self, "mean_")
        delattr(self, "scale_")

    def transform(self, X):
        X = np.asarray(X)

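A small illustration of the padding helpers above; the shapes are the point, the values are arbitrary:

    import numpy as np

    x = np.ones(3)
    _pad_data(x, 5)            # -> array([1., 1., 1., 0., 0.])
    m = np.ones((80, 7))       # e.g. a mel spectrogram, 80 bands x 7 frames
    _pad_tensor(m, 10).shape   # -> (80, 10)
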
@@ -28,277 +28,334 @@ def split_dataset(items):
        return items_eval, items
    return items[:eval_split_size], items[eval_split_size:]


# from https://gist.github.com/jihunchoi/f1434a77df9db1bb337417854b398df1
def sequence_mask(sequence_length, max_len=None):
    if max_len is None:
        max_len = sequence_length.data.max()
    seq_range = torch.arange(max_len,
                             dtype=sequence_length.dtype,
                             device=sequence_length.device)
    seq_range = torch.arange(max_len, dtype=sequence_length.dtype, device=sequence_length.device)
    # B x T_max
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)

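A quick example of the mask this produces, one row per sequence with True up to each length:

    import torch

    lengths = torch.tensor([3, 1])
    sequence_mask(lengths)
    # tensor([[ True,  True,  True],
    #         [ True, False, False]])
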
def to_camel(text):
    text = text.capitalize()
    text = re.sub(r'(?!^)_([a-zA-Z])', lambda m: m.group(1).upper(), text)
    text = text.replace('Tts', 'TTS')
    text = re.sub(r"(?!^)_([a-zA-Z])", lambda m: m.group(1).upper(), text)
    text = text.replace("Tts", "TTS")
    return text

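For reference, this is how config model names map onto class names:

    to_camel("tacotron2")      # -> "Tacotron2"
    to_camel("glow_tts")       # -> "GlowTTS"
    to_camel("speedy_speech")  # -> "SpeedySpeech"
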
def setup_model(num_chars, num_speakers, c, speaker_embedding_dim=None):
    print(" > Using model: {}".format(c.model))
    MyModel = importlib.import_module('TTS.tts.models.' + c.model.lower())
    MyModel = importlib.import_module("TTS.tts.models." + c.model.lower())
    MyModel = getattr(MyModel, to_camel(c.model))
    if c.model.lower() in "tacotron":
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        num_speakers=num_speakers,
                        r=c.r,
                        postnet_output_dim=int(c.audio['fft_size'] / 2 + 1),
                        decoder_output_dim=c.audio['num_mels'],
                        gst=c.use_gst,
                        gst_embedding_dim=c.gst['gst_embedding_dim'],
                        gst_num_heads=c.gst['gst_num_heads'],
                        gst_style_tokens=c.gst['gst_style_tokens'],
                        gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'],
                        memory_size=c.memory_size,
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
                        prenet_dropout=c.prenet_dropout,
                        forward_attn=c.use_forward_attn,
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
                        ddc_r=c.ddc_r,
                        speaker_embedding_dim=speaker_embedding_dim)
        model = MyModel(
            num_chars=num_chars + getattr(c, "add_blank", False),
            num_speakers=num_speakers,
            r=c.r,
            postnet_output_dim=int(c.audio["fft_size"] / 2 + 1),
            decoder_output_dim=c.audio["num_mels"],
            gst=c.use_gst,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
            memory_size=c.memory_size,
            attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder,
            double_decoder_consistency=c.double_decoder_consistency,
            ddc_r=c.ddc_r,
            speaker_embedding_dim=speaker_embedding_dim,
        )
    elif c.model.lower() == "tacotron2":
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        num_speakers=num_speakers,
                        r=c.r,
                        postnet_output_dim=c.audio['num_mels'],
                        decoder_output_dim=c.audio['num_mels'],
                        gst=c.use_gst,
                        gst_embedding_dim=c.gst['gst_embedding_dim'],
                        gst_num_heads=c.gst['gst_num_heads'],
                        gst_style_tokens=c.gst['gst_style_tokens'],
                        gst_use_speaker_embedding=c.gst['gst_use_speaker_embedding'],
                        attn_type=c.attention_type,
                        attn_win=c.windowing,
                        attn_norm=c.attention_norm,
                        prenet_type=c.prenet_type,
                        prenet_dropout=c.prenet_dropout,
                        forward_attn=c.use_forward_attn,
                        trans_agent=c.transition_agent,
                        forward_attn_mask=c.forward_attn_mask,
                        location_attn=c.location_attn,
                        attn_K=c.attention_heads,
                        separate_stopnet=c.separate_stopnet,
                        bidirectional_decoder=c.bidirectional_decoder,
                        double_decoder_consistency=c.double_decoder_consistency,
                        ddc_r=c.ddc_r,
                        speaker_embedding_dim=speaker_embedding_dim)
        model = MyModel(
            num_chars=num_chars + getattr(c, "add_blank", False),
            num_speakers=num_speakers,
            r=c.r,
            postnet_output_dim=c.audio["num_mels"],
            decoder_output_dim=c.audio["num_mels"],
            gst=c.use_gst,
            gst_embedding_dim=c.gst["gst_embedding_dim"],
            gst_num_heads=c.gst["gst_num_heads"],
            gst_style_tokens=c.gst["gst_style_tokens"],
            gst_use_speaker_embedding=c.gst["gst_use_speaker_embedding"],
            attn_type=c.attention_type,
            attn_win=c.windowing,
            attn_norm=c.attention_norm,
            prenet_type=c.prenet_type,
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            attn_K=c.attention_heads,
            separate_stopnet=c.separate_stopnet,
            bidirectional_decoder=c.bidirectional_decoder,
            double_decoder_consistency=c.double_decoder_consistency,
            ddc_r=c.ddc_r,
            speaker_embedding_dim=speaker_embedding_dim,
        )
    elif c.model.lower() == "glow_tts":
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        hidden_channels_enc=c['hidden_channels_encoder'],
                        hidden_channels_dec=c['hidden_channels_decoder'],
                        hidden_channels_dp=c['hidden_channels_duration_predictor'],
                        out_channels=c.audio['num_mels'],
                        encoder_type=c.encoder_type,
                        encoder_params=c.encoder_params,
                        use_encoder_prenet=c["use_encoder_prenet"],
                        num_flow_blocks_dec=12,
                        kernel_size_dec=5,
                        dilation_rate=1,
                        num_block_layers=4,
                        dropout_p_dec=0.05,
                        num_speakers=num_speakers,
                        c_in_channels=0,
                        num_splits=4,
                        num_squeeze=2,
                        sigmoid_scale=False,
                        mean_only=True,
                        external_speaker_embedding_dim=speaker_embedding_dim)
        model = MyModel(
            num_chars=num_chars + getattr(c, "add_blank", False),
            hidden_channels_enc=c["hidden_channels_encoder"],
            hidden_channels_dec=c["hidden_channels_decoder"],
            hidden_channels_dp=c["hidden_channels_duration_predictor"],
            out_channels=c.audio["num_mels"],
            encoder_type=c.encoder_type,
            encoder_params=c.encoder_params,
            use_encoder_prenet=c["use_encoder_prenet"],
            num_flow_blocks_dec=12,
            kernel_size_dec=5,
            dilation_rate=1,
            num_block_layers=4,
            dropout_p_dec=0.05,
            num_speakers=num_speakers,
            c_in_channels=0,
            num_splits=4,
            num_squeeze=2,
            sigmoid_scale=False,
            mean_only=True,
            external_speaker_embedding_dim=speaker_embedding_dim,
        )
    elif c.model.lower() == "speedy_speech":
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        out_channels=c.audio['num_mels'],
                        hidden_channels=c['hidden_channels'],
                        positional_encoding=c['positional_encoding'],
                        encoder_type=c['encoder_type'],
                        encoder_params=c['encoder_params'],
                        decoder_type=c['decoder_type'],
                        decoder_params=c['decoder_params'],
                        c_in_channels=0)
        model = MyModel(
            num_chars=num_chars + getattr(c, "add_blank", False),
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            positional_encoding=c["positional_encoding"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    elif c.model.lower() == "align_tts":
        model = MyModel(num_chars=num_chars + getattr(c, "add_blank", False),
                        out_channels=c.audio['num_mels'],
                        hidden_channels=c['hidden_channels'],
                        hidden_channels_dp=c['hidden_channels_dp'],
                        encoder_type=c['encoder_type'],
                        encoder_params=c['encoder_params'],
                        decoder_type=c['decoder_type'],
                        decoder_params=c['decoder_params'],
                        c_in_channels=0)
        model = MyModel(
            num_chars=num_chars + getattr(c, "add_blank", False),
            out_channels=c.audio["num_mels"],
            hidden_channels=c["hidden_channels"],
            hidden_channels_dp=c["hidden_channels_dp"],
            encoder_type=c["encoder_type"],
            encoder_params=c["encoder_params"],
            decoder_type=c["decoder_type"],
            decoder_params=c["decoder_params"],
            c_in_channels=0,
        )
    return model


def is_tacotron(c):
    return 'tacotron' in c['model'].lower()
    return "tacotron" in c["model"].lower()

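For reference, the dispatch here is purely name-based (illustrative values):

    is_tacotron({"model": "Tacotron2"})  # -> True
    is_tacotron({"model": "glow_tts"})   # -> False
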
def check_config_tts(c):
    check_argument('model', c, enum_list=['tacotron', 'tacotron2', 'glow_tts', 'speedy_speech', 'align_tts'], restricted=True, val_type=str)
    check_argument('run_name', c, restricted=True, val_type=str)
    check_argument('run_description', c, val_type=str)
    check_argument(
        "model",
        c,
        enum_list=["tacotron", "tacotron2", "glow_tts", "speedy_speech", "align_tts"],
        restricted=True,
        val_type=str,
    )
    check_argument("run_name", c, restricted=True, val_type=str)
    check_argument("run_description", c, val_type=str)

    # AUDIO
    check_argument('audio', c, restricted=True, val_type=dict)
    check_argument("audio", c, restricted=True, val_type=dict)

    # audio processing parameters
    check_argument('num_mels', c['audio'], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument('fft_size', c['audio'], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument('sample_rate', c['audio'], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument('frame_length_ms', c['audio'], restricted=True, val_type=float, min_val=10, max_val=1000, alternative='win_length')
    check_argument('frame_shift_ms', c['audio'], restricted=True, val_type=float, min_val=1, max_val=1000, alternative='hop_length')
    check_argument('preemphasis', c['audio'], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument('min_level_db', c['audio'], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument('ref_level_db', c['audio'], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000)
    check_argument("num_mels", c["audio"], restricted=True, val_type=int, min_val=10, max_val=2056)
    check_argument("fft_size", c["audio"], restricted=True, val_type=int, min_val=128, max_val=4058)
    check_argument("sample_rate", c["audio"], restricted=True, val_type=int, min_val=512, max_val=100000)
    check_argument(
        "frame_length_ms",
        c["audio"],
        restricted=True,
        val_type=float,
        min_val=10,
        max_val=1000,
        alternative="win_length",
    )
    check_argument(
        "frame_shift_ms", c["audio"], restricted=True, val_type=float, min_val=1, max_val=1000, alternative="hop_length"
    )
    check_argument("preemphasis", c["audio"], restricted=True, val_type=float, min_val=0, max_val=1)
    check_argument("min_level_db", c["audio"], restricted=True, val_type=int, min_val=-1000, max_val=10)
    check_argument("ref_level_db", c["audio"], restricted=True, val_type=int, min_val=0, max_val=1000)
    check_argument("power", c["audio"], restricted=True, val_type=float, min_val=1, max_val=5)
    check_argument("griffin_lim_iters", c["audio"], restricted=True, val_type=int, min_val=10, max_val=1000)

    # vocabulary parameters
    check_argument('characters', c, restricted=False, val_type=dict)
    check_argument('pad', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    check_argument('eos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    check_argument('bos', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    check_argument('characters', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    check_argument('phonemes', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys() and c['use_phonemes'], val_type=str)
    check_argument('punctuations', c['characters'] if 'characters' in c.keys() else {}, restricted='characters' in c.keys(), val_type=str)
    check_argument("characters", c, restricted=False, val_type=dict)
    check_argument(
        "pad", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
    )
    check_argument(
        "eos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
    )
    check_argument(
        "bos", c["characters"] if "characters" in c.keys() else {}, restricted="characters" in c.keys(), val_type=str
    )
    check_argument(
        "characters",
        c["characters"] if "characters" in c.keys() else {},
        restricted="characters" in c.keys(),
        val_type=str,
    )
    check_argument(
        "phonemes",
        c["characters"] if "characters" in c.keys() else {},
        restricted="characters" in c.keys() and c["use_phonemes"],
        val_type=str,
    )
    check_argument(
        "punctuations",
        c["characters"] if "characters" in c.keys() else {},
        restricted="characters" in c.keys(),
        val_type=str,
    )

    # normalization parameters
    check_argument('signal_norm', c['audio'], restricted=True, val_type=bool)
    check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool)
    check_argument('max_norm', c['audio'], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    check_argument('clip_norm', c['audio'], restricted=True, val_type=bool)
    check_argument('mel_fmin', c['audio'], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    check_argument('mel_fmax', c['audio'], restricted=True, val_type=float, min_val=500.0)
    check_argument('spec_gain', c['audio'], restricted=True, val_type=[int, float], min_val=1, max_val=100)
    check_argument('do_trim_silence', c['audio'], restricted=True, val_type=bool)
    check_argument('trim_db', c['audio'], restricted=True, val_type=int)
    check_argument("signal_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("symmetric_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("max_norm", c["audio"], restricted=True, val_type=float, min_val=0.1, max_val=1000)
    check_argument("clip_norm", c["audio"], restricted=True, val_type=bool)
    check_argument("mel_fmin", c["audio"], restricted=True, val_type=float, min_val=0.0, max_val=1000)
    check_argument("mel_fmax", c["audio"], restricted=True, val_type=float, min_val=500.0)
    check_argument("spec_gain", c["audio"], restricted=True, val_type=[int, float], min_val=1, max_val=100)
    check_argument("do_trim_silence", c["audio"], restricted=True, val_type=bool)
    check_argument("trim_db", c["audio"], restricted=True, val_type=int)

    # training parameters
    check_argument('batch_size', c, restricted=True, val_type=int, min_val=1)
    check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1)
    check_argument('r', c, restricted=True, val_type=int, min_val=1)
    check_argument('gradual_training', c, restricted=False, val_type=list)
    check_argument('mixed_precision', c, restricted=False, val_type=bool)
    check_argument("batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1)
    check_argument("r", c, restricted=True, val_type=int, min_val=1)
    check_argument("gradual_training", c, restricted=False, val_type=list)
    check_argument("mixed_precision", c, restricted=False, val_type=bool)
    # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100)

    # loss parameters
    check_argument('loss_masking', c, restricted=True, val_type=bool)
    if c['model'].lower() in ['tacotron', 'tacotron2']:
        check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0)
    if c['model'].lower() in ["speedy_speech", "align_tts"]:
        check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0)
        check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0)
    check_argument("loss_masking", c, restricted=True, val_type=bool)
    if c["model"].lower() in ["tacotron", "tacotron2"]:
        check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0)
    if c["model"].lower() in ["speedy_speech", "align_tts"]:
        check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0)
        check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0)

    # validation parameters
    check_argument('run_eval', c, restricted=True, val_type=bool)
    check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0)
    check_argument('test_sentences_file', c, restricted=False, val_type=str)
    check_argument("run_eval", c, restricted=True, val_type=bool)
    check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0)
    check_argument("test_sentences_file", c, restricted=False, val_type=str)

    # optimizer
    check_argument('noam_schedule', c, restricted=False, val_type=bool)
    check_argument('grad_clip', c, restricted=True, val_type=float, min_val=0.0)
    check_argument('epochs', c, restricted=True, val_type=int, min_val=1)
    check_argument('lr', c, restricted=True, val_type=float, min_val=0)
    check_argument('wd', c, restricted=is_tacotron(c), val_type=float, min_val=0)
    check_argument('warmup_steps', c, restricted=True, val_type=int, min_val=0)
    check_argument('seq_len_norm', c, restricted=is_tacotron(c), val_type=bool)
    check_argument("noam_schedule", c, restricted=False, val_type=bool)
    check_argument("grad_clip", c, restricted=True, val_type=float, min_val=0.0)
    check_argument("epochs", c, restricted=True, val_type=int, min_val=1)
    check_argument("lr", c, restricted=True, val_type=float, min_val=0)
    check_argument("wd", c, restricted=is_tacotron(c), val_type=float, min_val=0)
    check_argument("warmup_steps", c, restricted=True, val_type=int, min_val=0)
    check_argument("seq_len_norm", c, restricted=is_tacotron(c), val_type=bool)

    # tacotron prenet
    check_argument('memory_size', c, restricted=is_tacotron(c), val_type=int, min_val=-1)
    check_argument('prenet_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['original', 'bn'])
    check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
    check_argument("memory_size", c, restricted=is_tacotron(c), val_type=int, min_val=-1)
    check_argument("prenet_type", c, restricted=is_tacotron(c), val_type=str, enum_list=["original", "bn"])
    check_argument("prenet_dropout", c, restricted=is_tacotron(c), val_type=bool)

    # attention
    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
    check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
    check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
    check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('use_forward_attn', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('forward_attn_mask', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('transition_agent', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('location_attn', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('bidirectional_decoder', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('double_decoder_consistency', c, restricted=is_tacotron(c), val_type=bool)
    check_argument('ddc_r', c, restricted='double_decoder_consistency' in c.keys(), min_val=1, max_val=7, val_type=int)
    check_argument(
        "attention_type",
        c,
        restricted=is_tacotron(c),
        val_type=str,
        enum_list=["graves", "original", "dynamic_convolution"],
    )
    check_argument("attention_heads", c, restricted=is_tacotron(c), val_type=int)
    check_argument("attention_norm", c, restricted=is_tacotron(c), val_type=str, enum_list=["sigmoid", "softmax"])
    check_argument("windowing", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("use_forward_attn", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("forward_attn_mask", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("transition_agent", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("location_attn", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("bidirectional_decoder", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("double_decoder_consistency", c, restricted=is_tacotron(c), val_type=bool)
    check_argument("ddc_r", c, restricted="double_decoder_consistency" in c.keys(), min_val=1, max_val=7, val_type=int)

    if c['model'].lower() in ['tacotron', 'tacotron2']:
    if c["model"].lower() in ["tacotron", "tacotron2"]:
        # stopnet
        check_argument('stopnet', c, restricted=is_tacotron(c), val_type=bool)
        check_argument('separate_stopnet', c, restricted=is_tacotron(c), val_type=bool)
        check_argument("stopnet", c, restricted=is_tacotron(c), val_type=bool)
        check_argument("separate_stopnet", c, restricted=is_tacotron(c), val_type=bool)

    # Model Parameters for non-tacotron models
    if c['model'].lower() in ["speedy_speech", "align_tts"]:
        check_argument('positional_encoding', c, restricted=True, val_type=type)
        check_argument('encoder_type', c, restricted=True, val_type=str)
        check_argument('encoder_params', c, restricted=True, val_type=dict)
        check_argument('decoder_residual_conv_bn_params', c, restricted=True, val_type=dict)
    if c["model"].lower() in ["speedy_speech", "align_tts"]:
        check_argument("positional_encoding", c, restricted=True, val_type=type)
        check_argument("encoder_type", c, restricted=True, val_type=str)
        check_argument("encoder_params", c, restricted=True, val_type=dict)
        check_argument("decoder_residual_conv_bn_params", c, restricted=True, val_type=dict)

    # GlowTTS parameters
    check_argument('encoder_type', c, restricted=not is_tacotron(c), val_type=str)
    check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str)

    # tensorboard
    check_argument('print_step', c, restricted=True, val_type=int, min_val=1)
    check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1)
    check_argument('save_step', c, restricted=True, val_type=int, min_val=1)
    check_argument('checkpoint', c, restricted=True, val_type=bool)
    check_argument('tb_model_param_stats', c, restricted=True, val_type=bool)
    check_argument("print_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("save_step", c, restricted=True, val_type=int, min_val=1)
    check_argument("checkpoint", c, restricted=True, val_type=bool)
    check_argument("tb_model_param_stats", c, restricted=True, val_type=bool)

    # dataloading
    # pylint: disable=import-outside-toplevel
    from TTS.tts.utils.text import cleaners
    check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners))
    check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool)
    check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0)
    check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0)
    check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0)
    check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0)
    check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10)
    check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool)
    check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners))
    check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool)
    check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0)
    check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0)
    check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0)
    check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10)
    check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool)

    # paths
    check_argument('output_path', c, restricted=True, val_type=str)
    check_argument("output_path", c, restricted=True, val_type=str)

    # multi-speaker and gst
    check_argument('use_speaker_embedding', c, restricted=True, val_type=bool)
    check_argument('use_external_speaker_embedding_file', c, restricted=c['use_speaker_embedding'], val_type=bool)
    check_argument('external_speaker_embedding_file', c, restricted=c['use_external_speaker_embedding_file'], val_type=str)
    if c['model'].lower() in ['tacotron', 'tacotron2'] and c['use_gst']:
        check_argument('use_gst', c, restricted=is_tacotron(c), val_type=bool)
        check_argument('gst', c, restricted=is_tacotron(c), val_type=dict)
        check_argument('gst_style_input', c['gst'], restricted=is_tacotron(c), val_type=[str, dict])
        check_argument('gst_embedding_dim', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
        check_argument('gst_use_speaker_embedding', c['gst'], restricted=is_tacotron(c), val_type=bool)
        check_argument('gst_num_heads', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
        check_argument('gst_style_tokens', c['gst'], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)
    check_argument("use_speaker_embedding", c, restricted=True, val_type=bool)
    check_argument("use_external_speaker_embedding_file", c, restricted=c["use_speaker_embedding"], val_type=bool)
    check_argument(
        "external_speaker_embedding_file", c, restricted=c["use_external_speaker_embedding_file"], val_type=str
    )
    if c["model"].lower() in ["tacotron", "tacotron2"] and c["use_gst"]:
        check_argument("use_gst", c, restricted=is_tacotron(c), val_type=bool)
        check_argument("gst", c, restricted=is_tacotron(c), val_type=dict)
        check_argument("gst_style_input", c["gst"], restricted=is_tacotron(c), val_type=[str, dict])
        check_argument("gst_embedding_dim", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=0, max_val=1000)
        check_argument("gst_use_speaker_embedding", c["gst"], restricted=is_tacotron(c), val_type=bool)
        check_argument("gst_num_heads", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=2, max_val=10)
        check_argument("gst_style_tokens", c["gst"], restricted=is_tacotron(c), val_type=int, min_val=1, max_val=1000)

    # datasets - checking only the first entry
    check_argument('datasets', c, restricted=True, val_type=list)
    for dataset_entry in c['datasets']:
        check_argument('name', dataset_entry, restricted=True, val_type=str)
        check_argument('path', dataset_entry, restricted=True, val_type=str)
        check_argument('meta_file_train', dataset_entry, restricted=True, val_type=[str, list])
        check_argument('meta_file_val', dataset_entry, restricted=True, val_type=str)
    check_argument("datasets", c, restricted=True, val_type=list)
    for dataset_entry in c["datasets"]:
        check_argument("name", dataset_entry, restricted=True, val_type=str)
        check_argument("path", dataset_entry, restricted=True, val_type=str)
        check_argument("meta_file_train", dataset_entry, restricted=True, val_type=[str, list])
        check_argument("meta_file_val", dataset_entry, restricted=True, val_type=str)

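check_argument itself is imported from the package's generic utilities and is not part of this diff; as a rough sketch of the contract these calls rely on (an approximation for orientation, not the repo's exact code):

    def check_argument(name, cfg, restricted=False, val_type=None, min_val=None, max_val=None,
                       enum_list=None, alternative=None):
        # a required key may be satisfied by its legacy alternative name
        if name not in cfg and alternative is not None and alternative in cfg:
            return
        if restricted:
            assert name in cfg, f" [!] {name} not defined in config.json"
        if name in cfg and cfg[name] is not None:
            types = val_type if isinstance(val_type, list) else [val_type]
            if val_type is not None:
                assert any(isinstance(cfg[name], t) for t in types), f" [!] {name} has the wrong type"
            if min_val is not None:
                assert cfg[name] >= min_val, f" [!] {name} is smaller than {min_val}"
            if max_val is not None:
                assert cfg[name] <= max_val, f" [!] {name} is larger than {max_val}"
            if enum_list is not None and isinstance(cfg[name], str):
                assert cfg[name].lower() in enum_list, f" [!] {name} is not a valid value"
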
@@ -6,7 +6,6 @@ import pickle as pickle_tts
from TTS.utils.io import RenamingUnpickler


def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False):  # pylint: disable=redefined-builtin
    """Load ```TTS.tts.models``` checkpoints.

@@ -20,33 +19,25 @@ def load_checkpoint(model, checkpoint_path, amp=None, use_cuda=False, eval=False
        [type]: [description]
    """
    try:
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    except ModuleNotFoundError:
        pickle_tts.Unpickler = RenamingUnpickler
        state = torch.load(checkpoint_path, map_location=torch.device('cpu'), pickle_module=pickle_tts)
    model.load_state_dict(state['model'])
    if amp and 'amp' in state:
        amp.load_state_dict(state['amp'])
        state = torch.load(checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts)
    model.load_state_dict(state["model"])
    if amp and "amp" in state:
        amp.load_state_dict(state["amp"])
    if use_cuda:
        model.cuda()
    # set model stepsize
    if hasattr(model.decoder, 'r'):
        model.decoder.set_r(state['r'])
        print(" > Model r: ", state['r'])
    if hasattr(model.decoder, "r"):
        model.decoder.set_r(state["r"])
        print(" > Model r: ", state["r"])
    if eval:
        model.eval()
    return model, state


def save_model(model,
               optimizer,
               current_step,
               epoch,
               r,
               output_path,
               characters,
               amp_state_dict=None,
               **kwargs):
def save_model(model, optimizer, current_step, epoch, r, output_path, characters, amp_state_dict=None, **kwargs):
    """Save ```TTS.tts.models``` states with extra fields.

    Args:

@@ -59,27 +50,26 @@ def save_model(model,
        characters (list): list of characters used in the model.
        amp_state_dict (state_dict, optional): Apex.amp state dict if Apex is enabled. Defaults to None.
    """
    if hasattr(model, 'module'):
    if hasattr(model, "module"):
        model_state = model.module.state_dict()
    else:
        model_state = model.state_dict()
    state = {
        'model': model_state,
        'optimizer': optimizer.state_dict() if optimizer is not None else None,
        'step': current_step,
        'epoch': epoch,
        'date': datetime.date.today().strftime("%B %d, %Y"),
        'r': r,
        'characters': characters
        "model": model_state,
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        "r": r,
        "characters": characters,
    }
    if amp_state_dict:
        state['amp'] = amp_state_dict
        state["amp"] = amp_state_dict
    state.update(kwargs)
    torch.save(state, output_path)


def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
                    characters, **kwargs):
def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs):
    """Save model checkpoint, intended for saving checkpoints at training.

    Args:

@@ -91,14 +81,15 @@ def save_checkpoint(model, optimizer, current_step, epoch, r, output_folder,
        output_path (str): output path to save the model file.
        characters (list): list of characters used in the model.
    """
    file_name = 'checkpoint_{}.pth.tar'.format(current_step)
    file_name = "checkpoint_{}.pth.tar".format(current_step)
    checkpoint_path = os.path.join(output_folder, file_name)
    print(" > CHECKPOINT : {}".format(checkpoint_path))
    save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, **kwargs)


def save_best_model(target_loss, best_loss, model, optimizer, current_step,
                    epoch, r, output_folder, characters, **kwargs):
def save_best_model(
    target_loss, best_loss, model, optimizer, current_step, epoch, r, output_folder, characters, **kwargs
):
    """Save model checkpoint, intended for saving the best model after each epoch.
    It compares the current model loss with the best loss so far and saves the
    model if the current loss is better.

@@ -118,9 +109,11 @@ def save_best_model(target_loss, best_loss, model, optimizer, current_step,
        float: updated current best loss.
    """
    if target_loss < best_loss:
        file_name = 'best_model.pth.tar'
        file_name = "best_model.pth.tar"
        checkpoint_path = os.path.join(output_folder, file_name)
        print(" >> BEST MODEL : {}".format(checkpoint_path))
        save_model(model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs)
        save_model(
            model, optimizer, current_step, epoch, r, checkpoint_path, characters, model_loss=target_loss, **kwargs
        )
        best_loss = target_loss
    return best_loss

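A hedged sketch of how these savers slot into a training loop; global_step, eval_loss, OUT_PATH and characters are placeholders assumed to exist in the caller:

    best_loss = float("inf")
    # inside the epoch/step loop:
    if global_step % c.save_step == 0:
        save_checkpoint(model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, characters)
    best_loss = save_best_model(
        eval_loss, best_loss, model, optimizer, global_step, epoch, model.decoder.r, OUT_PATH, characters
    )
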
@@ -10,7 +10,7 @@ def make_speakers_json_path(out_path):
def load_speaker_mapping(out_path):
    """Loads speaker mapping if already present."""
    try:
        if os.path.splitext(out_path)[1] == '.json':
        if os.path.splitext(out_path)[1] == ".json":
            json_file = out_path
        else:
            json_file = make_speakers_json_path(out_path)

@@ -19,6 +19,7 @@ def load_speaker_mapping(out_path):
    except FileNotFoundError:
        return {}


def save_speaker_mapping(out_path, speaker_mapping):
    """Saves speaker mapping if not yet present."""
    speakers_json_path = make_speakers_json_path(out_path)

@@ -31,40 +32,49 @@ def get_speakers(items):
    speakers = {e[2] for e in items}
    return sorted(speakers)


def parse_speakers(c, args, meta_data_train, OUT_PATH):
    """ Returns number of speakers, speaker embedding shape and speaker mapping"""
    if c.use_speaker_embedding:
        speakers = get_speakers(meta_data_train)
        if args.restore_path:
            if c.use_external_speaker_embedding_file: # if restore checkpoint and use External Embedding file
            if c.use_external_speaker_embedding_file:  # if restore checkpoint and use External Embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                if not speaker_mapping:
                    print("WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file")
                    print(
                        "WARNING: speakers.json was not found in restore_path, trying to use CONFIG.external_speaker_embedding_file"
                    )
                    speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
                    if not speaker_mapping:
                        raise RuntimeError("You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file")
                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
            elif not c.use_external_speaker_embedding_file: # if restore checkpoint and don't use External Embedding file
                        raise RuntimeError(
                            "You must copy the file speakers.json to restore_path, or set a valid file in CONFIG.external_speaker_embedding_file"
                        )
                speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
            elif (
                not c.use_external_speaker_embedding_file
            ):  # if restore checkpoint and don't use External Embedding file
                prev_out_path = os.path.dirname(args.restore_path)
                speaker_mapping = load_speaker_mapping(prev_out_path)
                speaker_embedding_dim = None
                assert all([speaker in speaker_mapping
                            for speaker in speakers]), "As of now, you cannot " \
                                                       "introduce new speakers to " \
                                                       "a previously trained model."
        elif c.use_external_speaker_embedding_file and c.external_speaker_embedding_file: # if start new train using External Embedding file
                assert all(speaker in speaker_mapping for speaker in speakers), (
                    "As of now, you cannot " "introduce new speakers to " "a previously trained model."
                )
        elif (
            c.use_external_speaker_embedding_file and c.external_speaker_embedding_file
        ):  # if start new train using External Embedding file
            speaker_mapping = load_speaker_mapping(c.external_speaker_embedding_file)
            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]['embedding'])
        elif c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file: # if start new train using External Embedding file and don't pass external embedding file
            speaker_embedding_dim = len(speaker_mapping[list(speaker_mapping.keys())[0]]["embedding"])
        elif (
            c.use_external_speaker_embedding_file and not c.external_speaker_embedding_file
        ):  # if start new train using External Embedding file and don't pass external embedding file
            raise RuntimeError(
                "use_external_speaker_embedding_file is True, so you need to pass an external speaker embedding file; run the GE2E-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb or AngularPrototypical-Speaker_Encoder-ExtractSpeakerEmbeddings-by-sample.ipynb notebook in the notebooks/ folder"
            )
        else: # if start new train and don't use External Embedding file
        else:  # if start new train and don't use External Embedding file
            speaker_mapping = {name: i for i, name in enumerate(speakers)}
            speaker_embedding_dim = None
        save_speaker_mapping(OUT_PATH, speaker_mapping)
        num_speakers = len(speaker_mapping)
        print(" > Training with {} speakers: {}".format(
            len(speakers), ", ".join(speakers)))
        print(" > Training with {} speakers: {}".format(len(speakers), ", ".join(speakers)))
    else:
        num_speakers = 0
        speaker_embedding_dim = None

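An illustrative call for the mapping helper above; items rows follow this repo's [text, wav_path, speaker_name] layout:

    items = [["hello", "a.wav", "spk1"], ["hi", "b.wav", "spk2"], ["hey", "c.wav", "spk1"]]
    get_speakers(items)  # -> ["spk1", "spk2"]
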
@@ -8,8 +8,9 @@ from torch.autograd import Variable


def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()
    gauss = torch.Tensor([exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)

@@ -24,25 +25,22 @@ def _ssim(img1, img2, window, window_size, channel, size_average=True):

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2,
        groups=channel) - mu1_mu2
    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super().__init__()

@@ -66,7 +64,6 @@ class SSIM(torch.nn.Module):
        self.window = window
        self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

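A quick sanity check for the SSIM module above; identical inputs should score ~1.0, with shapes batch x channel x H x W:

    import torch

    ssim = SSIM(window_size=11)
    x = torch.rand(1, 1, 64, 64)
    float(ssim(x, x))  # -> 1.0 (up to numerical precision)
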
@ -1,8 +1,10 @@
|
|||
import os
|
||||
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
||||
|
||||
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
||||
import pkg_resources
|
||||
installed = {pkg.key for pkg in pkg_resources.working_set} #pylint: disable=not-an-iterable
|
||||
if 'tensorflow' in installed or 'tensorflow-gpu' in installed:
|
||||
|
||||
installed = {pkg.key for pkg in pkg_resources.working_set} # pylint: disable=not-an-iterable
|
||||
if "tensorflow" in installed or "tensorflow-gpu" in installed:
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
@ -14,19 +16,26 @@ def text_to_seqvec(text, CONFIG):
|
|||
# text ot phonemes to sequence vector
|
||||
if CONFIG.use_phonemes:
|
||||
seq = np.asarray(
|
||||
phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language,
|
||||
CONFIG.enable_eos_bos_chars,
|
||||
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
|
||||
add_blank=CONFIG['add_blank'] if 'add_blank' in CONFIG.keys() else False),
|
||||
dtype=np.int32)
|
||||
phoneme_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
CONFIG.phoneme_language,
|
||||
CONFIG.enable_eos_bos_chars,
|
||||
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
|
||||
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
else:
|
||||
seq = np.asarray(text_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
tp=CONFIG.characters if 'characters' in CONFIG.keys() else None,
|
||||
add_blank=CONFIG['add_blank']
|
||||
if 'add_blank' in CONFIG.keys() else False),
|
||||
dtype=np.int32)
|
||||
seq = np.asarray(
|
||||
text_to_sequence(
|
||||
text,
|
||||
text_cleaner,
|
||||
tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
|
||||
add_blank=CONFIG["add_blank"] if "add_blank" in CONFIG.keys() else False,
|
||||
),
|
||||
dtype=np.int32,
|
||||
)
|
||||
return seq
|
||||
|
||||
|
||||
|
@ -47,86 +56,95 @@ def numpy_to_tf(np_array, dtype):
|
|||
|
||||
|
||||
def compute_style_mel(style_wav, ap, cuda=False):
|
||||
style_mel = torch.FloatTensor(ap.melspectrogram(
|
||||
ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
|
||||
style_mel = torch.FloatTensor(ap.melspectrogram(ap.load_wav(style_wav, sr=ap.sample_rate))).unsqueeze(0)
|
||||
if cuda:
|
||||
return style_mel.cuda()
|
||||
return style_mel
|
||||
|
||||
|
||||
def run_model_torch(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None, speaker_embeddings=None):
|
||||
if 'tacotron' in CONFIG.model.lower():
|
||||
if "tacotron" in CONFIG.model.lower():
|
||||
if CONFIG.use_gst:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
inputs, style_mel=style_mel, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
else:
|
||||
if truncated:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference_truncated(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
else:
|
||||
decoder_output, postnet_output, alignments, stop_tokens = model.inference(
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings)
|
||||
elif 'glow' in CONFIG.model.lower():
|
||||
inputs, speaker_ids=speaker_id, speaker_embeddings=speaker_embeddings
|
||||
)
|
||||
elif "glow" in CONFIG.model.lower():
|
||||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
|
||||
if hasattr(model, 'module'):
|
||||
if hasattr(model, "module"):
|
||||
# distributed model
|
||||
postnet_output, _, _, _, alignments, _, _ = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
|
||||
postnet_output, _, _, _, alignments, _, _ = model.module.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
else:
|
||||
postnet_output, _, _, _, alignments, _, _ = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
|
||||
postnet_output, _, _, _, alignments, _, _ = model.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
postnet_output = postnet_output.permute(0, 2, 1)
|
||||
# these only belong to tacotron models.
|
||||
decoder_output = None
|
||||
stop_tokens = None
|
||||
elif CONFIG.model.lower() in ['speedy_speech', 'align_tts']:
|
||||
elif CONFIG.model.lower() in ["speedy_speech", "align_tts"]:
|
||||
inputs_lengths = torch.tensor(inputs.shape[1:2]).to(inputs.device) # pylint: disable=not-callable
|
||||
if hasattr(model, 'module'):
|
||||
if hasattr(model, "module"):
|
||||
# distributed model
|
||||
postnet_output, alignments = model.module.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
|
||||
postnet_output, alignments = model.module.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
else:
|
||||
postnet_output, alignments = model.inference(inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings)
|
||||
postnet_output, alignments = model.inference(
|
||||
inputs, inputs_lengths, g=speaker_id if speaker_id is not None else speaker_embeddings
|
||||
)
|
||||
postnet_output = postnet_output.permute(0, 2, 1)
|
||||
# these only belong to tacotron models.
|
||||
decoder_output = None
|
||||
stop_tokens = None
|
||||
else:
|
||||
raise ValueError('[!] Unknown model name.')
|
||||
raise ValueError("[!] Unknown model name.")
|
||||
return decoder_output, postnet_output, alignments, stop_tokens
|
||||
|
||||
|
||||
def run_model_tf(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
    if CONFIG.use_gst and style_mel is not None:
        raise NotImplementedError(' [!] GST inference not implemented for TF')
        raise NotImplementedError(" [!] GST inference not implemented for TF")
    if truncated:
        raise NotImplementedError(' [!] Truncated inference not implemented for TF')
        raise NotImplementedError(" [!] Truncated inference not implemented for TF")
    if speaker_id is not None:
        raise NotImplementedError(' [!] Multi-Speaker not implemented for TF')
        raise NotImplementedError(" [!] Multi-Speaker not implemented for TF")
    # TODO: handle multispeaker case
    decoder_output, postnet_output, alignments, stop_tokens = model(
        inputs, training=False)
    decoder_output, postnet_output, alignments, stop_tokens = model(inputs, training=False)
    return decoder_output, postnet_output, alignments, stop_tokens

def run_model_tflite(model, inputs, CONFIG, truncated, speaker_id=None, style_mel=None):
    if CONFIG.use_gst and style_mel is not None:
        raise NotImplementedError(' [!] GST inference not implemented for TfLite')
        raise NotImplementedError(" [!] GST inference not implemented for TfLite")
    if truncated:
        raise NotImplementedError(' [!] Truncated inference not implemented for TfLite')
        raise NotImplementedError(" [!] Truncated inference not implemented for TfLite")
    if speaker_id is not None:
        raise NotImplementedError(' [!] Multi-Speaker not implemented for TfLite')
        raise NotImplementedError(" [!] Multi-Speaker not implemented for TfLite")
    # get input and output details
    input_details = model.get_input_details()
    output_details = model.get_output_details()
    # reshape input tensor for the new input shape
    model.resize_tensor_input(input_details[0]['index'], inputs.shape)
    model.resize_tensor_input(input_details[0]["index"], inputs.shape)
    model.allocate_tensors()
    detail = input_details[0]
    # input_shape = detail['shape']
    model.set_tensor(detail['index'], inputs)
    model.set_tensor(detail["index"], inputs)
    # run the model
    model.invoke()
    # collect outputs
    decoder_output = model.get_tensor(output_details[0]['index'])
    postnet_output = model.get_tensor(output_details[1]['index'])
    decoder_output = model.get_tensor(output_details[0]["index"])
    postnet_output = model.get_tensor(output_details[1]["index"])
    # tflite model only returns feature frames
    return decoder_output, postnet_output, None, None

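[Editor's note] An illustrative sketch (not in the commit) of how the `model` argument above is typically built; the .tflite path is hypothetical:

    import numpy as np
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path="tacotron2.tflite")  # hypothetical file
    token_ids = np.array([[12, 7, 42, 3]], dtype=np.int32)
    decoder_out, postnet_out, _, _ = run_model_tflite(interpreter, token_ids, CONFIG, truncated=False)
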
@@ -154,7 +172,7 @@ def parse_outputs_tflite(postnet_output, decoder_output):


def trim_silence(wav, ap):
    return wav[:ap.find_endpoint(wav)]
    return wav[: ap.find_endpoint(wav)]


def inv_spectrogram(postnet_output, ap, CONFIG):

@@ -186,13 +204,13 @@ def embedding_to_torch(speaker_embedding, cuda=False):

# TODO: perform GL with pytorch for batching
def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
    '''Apply griffin-lim to each sample iterating throught the first dimension.
    """Apply griffin-lim to each sample iterating throught the first dimension.
    Args:
        inputs (Tensor or np.Array): Features to be converted by GL. First dimension is the batch size.
        input_lens (Tensor or np.Array): 1D array of sample lengths.
        CONFIG (Dict): TTS config.
        ap (AudioProcessor): TTS audio processor.
    '''
    """
    wavs = []
    for idx, spec in enumerate(inputs):
        wav_len = (input_lens[idx] * ap.hop_length) - ap.hop_length  # inverse librosa padding

@@ -202,39 +220,41 @@ def apply_griffin_lim(inputs, input_lens, CONFIG, ap):
    return wavs

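[Editor's note] A quick sanity check (not in the commit) of the length math in apply_griffin_lim above: with hop_length=256, a 500-frame sample maps back to (500 * 256) - 256 = 127,744 waveform samples, i.e. (n_frames - 1) * hop_length.
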
def synthesis(model,
              text,
              CONFIG,
              use_cuda,
              ap,
              speaker_id=None,
              style_wav=None,
              truncated=False,
              enable_eos_bos_chars=False, #pylint: disable=unused-argument
              use_griffin_lim=False,
              do_trim_silence=False,
              speaker_embedding=None,
              backend='torch'):
def synthesis(
    model,
    text,
    CONFIG,
    use_cuda,
    ap,
    speaker_id=None,
    style_wav=None,
    truncated=False,
    enable_eos_bos_chars=False,  # pylint: disable=unused-argument
    use_griffin_lim=False,
    do_trim_silence=False,
    speaker_embedding=None,
    backend="torch",
):
    """Synthesize voice for the given text.

    Args:
        model (TTS.tts.models): model to synthesize.
        text (str): target text
        CONFIG (dict): config dictionary to be loaded from config.json.
        use_cuda (bool): enable cuda.
        ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process
            model outputs.
        speaker_id (int): id of speaker
        style_wav (str | Dict[str, float]): Uses for style embedding of GST.
        truncated (bool): keep model states after inference. It can be used
            for continuous inference at long texts.
        enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
        do_trim_silence (bool): trim silence after synthesis.
        backend (str): tf or torch
    Args:
        model (TTS.tts.models): model to synthesize.
        text (str): target text
        CONFIG (dict): config dictionary to be loaded from config.json.
        use_cuda (bool): enable cuda.
        ap (TTS.tts.utils.audio.AudioProcessor): audio processor to process
            model outputs.
        speaker_id (int): id of speaker
        style_wav (str | Dict[str, float]): Uses for style embedding of GST.
        truncated (bool): keep model states after inference. It can be used
            for continuous inference at long texts.
        enable_eos_bos_chars (bool): enable special chars for end of sentence and start of sentence.
        do_trim_silence (bool): trim silence after synthesis.
        backend (str): tf or torch
    """
    # GST processing
    style_mel = None
    if 'use_gst' in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
    if "use_gst" in CONFIG.keys() and CONFIG.use_gst and style_wav is not None:
        if isinstance(style_wav, dict):
            style_mel = style_wav
        else:

@@ -242,7 +262,7 @@ def synthesis(model,
    # preprocess the given text
    inputs = text_to_seqvec(text, CONFIG)
    # pass tensors to backend
    if backend == 'torch':
    if backend == "torch":
        if speaker_id is not None:
            speaker_id = id_to_torch(speaker_id, cuda=use_cuda)


@@ -253,31 +273,35 @@ def synthesis(model,
        style_mel = numpy_to_torch(style_mel, torch.float, cuda=use_cuda)
        inputs = numpy_to_torch(inputs, torch.long, cuda=use_cuda)
        inputs = inputs.unsqueeze(0)
    elif backend == 'tf':
    elif backend == "tf":
        # TODO: handle speaker id for tf model
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
    elif backend == 'tflite':
    elif backend == "tflite":
        style_mel = numpy_to_tf(style_mel, tf.float32)
        inputs = numpy_to_tf(inputs, tf.int32)
        inputs = tf.expand_dims(inputs, 0)
    # synthesize voice
    if backend == 'torch':
    if backend == "torch":
        decoder_output, postnet_output, alignments, stop_tokens = run_model_torch(
            model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding)
            model, inputs, CONFIG, truncated, speaker_id, style_mel, speaker_embeddings=speaker_embedding
        )
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_torch(
            postnet_output, decoder_output, alignments, stop_tokens)
    elif backend == 'tf':
            postnet_output, decoder_output, alignments, stop_tokens
        )
    elif backend == "tf":
        decoder_output, postnet_output, alignments, stop_tokens = run_model_tf(
            model, inputs, CONFIG, truncated, speaker_id, style_mel)
            model, inputs, CONFIG, truncated, speaker_id, style_mel
        )
        postnet_output, decoder_output, alignment, stop_tokens = parse_outputs_tf(
            postnet_output, decoder_output, alignments, stop_tokens)
    elif backend == 'tflite':
            postnet_output, decoder_output, alignments, stop_tokens
        )
    elif backend == "tflite":
        decoder_output, postnet_output, alignment, stop_tokens = run_model_tflite(
            model, inputs, CONFIG, truncated, speaker_id, style_mel)
        postnet_output, decoder_output = parse_outputs_tflite(
            postnet_output, decoder_output)
            model, inputs, CONFIG, truncated, speaker_id, style_mel
        )
        postnet_output, decoder_output = parse_outputs_tflite(postnet_output, decoder_output)
    # convert outputs to numpy
    # plot results
    wav = None

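[Editor's note] An illustrative call (not in the commit); `model`, `CONFIG` and `ap` are assumed to be a loaded model, config and AudioProcessor, and the exact tuple returned depends on the repo version, so it is unpacked generically:

    outputs = synthesis(model, "Hello world.", CONFIG, use_cuda=False, ap=ap, use_griffin_lim=True)
    wav = outputs[0]  # the waveform comes first when Griffin-Lim is enabled
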
@@ -6,8 +6,7 @@ import phonemizer
from packaging import version
from phonemizer.phonemize import phonemize
from TTS.tts.utils.text import cleaners
from TTS.tts.utils.text.symbols import (_bos, _eos, _punctuations,
                                        make_symbols, phonemes, symbols)
from TTS.tts.utils.text.symbols import _bos, _eos, _punctuations, make_symbols, phonemes, symbols
from TTS.tts.utils.chinese_mandarin.phonemizer import chinese_text_to_phonemes


@@ -22,14 +21,14 @@ _id_to_phonemes = {i: s for i, s in enumerate(phonemes)}
_symbols = symbols
_phonemes = phonemes
# Regular expression matching text enclosed in curly braces:
_CURLY_RE = re.compile(r'(.*?)\{(.+?)\}(.*)')
_CURLY_RE = re.compile(r"(.*?)\{(.+?)\}(.*)")

# Regular expression matching punctuations, ignoring empty space
PHONEME_PUNCTUATION_PATTERN = r'['+_punctuations.replace(' ', '')+']+'
PHONEME_PUNCTUATION_PATTERN = r"[" + _punctuations.replace(" ", "") + "]+"


def text2phone(text, language):
    '''Convert graphemes to phonemes. For most of the languages, it calls
    """Convert graphemes to phonemes. For most of the languages, it calls
    the phonemizer python library that calls espeak/espeak-ng. For chinese
    mandarin, it calls pypinyin + custom function for phonemizing
    Parameters:

@@ -38,60 +37,73 @@ def text2phone(text, language):
    Returns:
        ph (str): phonemes as a string seperated by "|"
            ph = "ɪ|g|ˈ|z|æ|m|p|ə|l"
    '''
    """

    # TO REVIEW : How to have a good implementation for this?
    if language == "zh-CN":
        ph = chinese_text_to_phonemes(text)
        return ph


    seperator = phonemizer.separator.Separator(' |', '', '|')
    #try:
    seperator = phonemizer.separator.Separator(" |", "", "|")
    # try:
    punctuations = re.findall(PHONEME_PUNCTUATION_PATTERN, text)
    if version.parse(phonemizer.__version__) < version.parse('2.1'):
        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language)
        ph = ph[:-1].strip() # skip the last empty character
    if version.parse(phonemizer.__version__) < version.parse("2.1"):
        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend="espeak", language=language)
        ph = ph[:-1].strip()  # skip the last empty character
        # phonemizer does not tackle punctuations. Here we do.
        # Replace \n with matching punctuations.
        if punctuations:
            # if text ends with a punctuation.
            if text[-1] == punctuations[-1]:
                for punct in punctuations[:-1]:
                    ph = ph.replace('| |\n', '|'+punct+'| |', 1)
                    ph = ph.replace("| |\n", "|" + punct + "| |", 1)
                ph = ph + punctuations[-1]
            else:
                for punct in punctuations:
                    ph = ph.replace('| |\n', '|'+punct+'| |', 1)
    elif version.parse(phonemizer.__version__) >= version.parse('2.1'):
        ph = phonemize(text, separator=seperator, strip=False, njobs=1, backend='espeak', language=language, preserve_punctuation=True, language_switch='remove-flags')
                    ph = ph.replace("| |\n", "|" + punct + "| |", 1)
    elif version.parse(phonemizer.__version__) >= version.parse("2.1"):
        ph = phonemize(
            text,
            separator=seperator,
            strip=False,
            njobs=1,
            backend="espeak",
            language=language,
            preserve_punctuation=True,
            language_switch="remove-flags",
        )
        # this is a simple fix for phonemizer.
        # https://github.com/bootphon/phonemizer/issues/32
        if punctuations:
            for punctuation in punctuations:
                ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(f"| |{punctuation}", f"|{punctuation}| |")
                ph = ph.replace(f"| |{punctuation} ", f"|{punctuation}| |").replace(
                    f"| |{punctuation}", f"|{punctuation}| |"
                )
            ph = ph[:-3]
    else:
        raise RuntimeError(" [!] Use 'phonemizer' version 2.1 or older.")

    return ph

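[Editor's note] The docstring's own example, restated as a call (not in the commit; requires espeak/espeak-ng plus the phonemizer package):

    ph = text2phone("example", "en-us")
    # ph should be "ɪ|g|ˈ|z|æ|m|p|ə|l" (phonemes joined by "|")
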
def intersperse(sequence, token):
    result = [token] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result

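[Editor's note] A concrete check (not in the commit) of intersperse above; the blank token lands before, between and after every item:

    assert intersperse([1, 2, 3], 0) == [0, 1, 0, 2, 0, 3, 0]
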
def pad_with_eos_bos(phoneme_sequence, tp=None):
    # pylint: disable=global-statement
    global _phonemes_to_id, _bos, _eos
    if tp:
        _bos = tp['bos']
        _eos = tp['eos']
        _bos = tp["bos"]
        _eos = tp["eos"]
        _, _phonemes = make_symbols(**tp)
        _phonemes_to_id = {s: i for i, s in enumerate(_phonemes)}

    return [_phonemes_to_id[_bos]] + list(phoneme_sequence) + [_phonemes_to_id[_eos]]


def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None, add_blank=False):
    # pylint: disable=global-statement
    global _phonemes_to_id, _phonemes

@@ -105,23 +117,23 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=
    if to_phonemes is None:
        print("!! After phoneme conversion the result is None. -- {} ".format(clean_text))
    # iterate by skipping empty strings - NOTE: might be useful to keep it to have a better intonation.
    for phoneme in filter(None, to_phonemes.split('|')):
    for phoneme in filter(None, to_phonemes.split("|")):
        sequence += _phoneme_to_sequence(phoneme)
    # Append EOS char
    if enable_eos_bos:
        sequence = pad_with_eos_bos(sequence, tp=tp)
    if add_blank:
        sequence = intersperse(sequence, len(_phonemes)) # add a blank token (new), whose id number is len(_phonemes)
        sequence = intersperse(sequence, len(_phonemes))  # add a blank token (new), whose id number is len(_phonemes)
    return sequence

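[Editor's note] Illustrative usage (not in the commit); this needs a working phonemizer backend, and the returned ids depend on the active symbol set:

    seq = phoneme_to_sequence("Hello world.", ["phoneme_cleaners"], "en-us",
                              enable_eos_bos=False, add_blank=True)
    # with add_blank=True, the blank id len(_phonemes) is interspersed as shown above
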
def sequence_to_phoneme(sequence, tp=None, add_blank=False):
    # pylint: disable=global-statement
    '''Converts a sequence of IDs back to a string'''
    """Converts a sequence of IDs back to a string"""
    global _id_to_phonemes, _phonemes
    if add_blank:
        sequence = list(filter(lambda x: x != len(_phonemes), sequence))
    result = ''
    result = ""
    if tp:
        _, _phonemes = make_symbols(**tp)
        _id_to_phonemes = {i: s for i, s in enumerate(_phonemes)}

@@ -130,22 +142,22 @@ def sequence_to_phoneme(sequence, tp=None, add_blank=False):
        if symbol_id in _id_to_phonemes:
            s = _id_to_phonemes[symbol_id]
            result += s
    return result.replace('}{', ' ')
    return result.replace("}{", " ")

def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
    """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through
    Args:
        text: string to convert to a sequence
        cleaner_names: names of the cleaner functions to run the text through

    Returns:
        List of integers corresponding to the symbols in the text
    '''
    Returns:
        List of integers corresponding to the symbols in the text
    """
    # pylint: disable=global-statement
    global _symbol_to_id, _symbols
    if tp:

@@ -159,18 +171,17 @@ def text_to_sequence(text, cleaner_names, tp=None, add_blank=False):
        if not m:
            sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
            break
        sequence += _symbols_to_sequence(
            _clean_text(m.group(1), cleaner_names))
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    if add_blank:
        sequence = intersperse(sequence, len(_symbols)) # add a blank token (new), whose id number is len(_symbols)
        sequence = intersperse(sequence, len(_symbols))  # add a blank token (new), whose id number is len(_symbols)
    return sequence

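[Editor's note] A sketch (not in the commit) of the curly-brace handling above, using the docstring's own example; the returned ids vary with the symbol table:

    seq = text_to_sequence("Turn left on {HH AW1 S S T AH0 N} Street.", ["english_cleaners"])
    # _CURLY_RE splits the plain text from the {ARPAbet} span, which goes through
    # _arpabet_to_sequence() with the "@" prefix added per symbol.
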
def sequence_to_text(sequence, tp=None, add_blank=False):
    '''Converts a sequence of IDs back to a string'''
    """Converts a sequence of IDs back to a string"""
    # pylint: disable=global-statement
    global _id_to_symbol, _symbols
    if add_blank:

@@ -180,22 +191,22 @@ def sequence_to_text(sequence, tp=None, add_blank=False):
        _symbols, _ = make_symbols(**tp)
        _id_to_symbol = {i: s for i, s in enumerate(_symbols)}

    result = ''
    result = ""
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            if len(s) > 1 and s[0] == "@":
                s = "{%s}" % s[1:]
            result += s
    return result.replace('}{', ' ')
    return result.replace("}{", " ")


def _clean_text(text, cleaner_names):
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
            raise Exception("Unknown cleaner: %s" % name)
        text = cleaner(text)
    return text


@@ -209,12 +220,12 @@ def _phoneme_to_sequence(phons):


def _arpabet_to_sequence(text):
    return _symbols_to_sequence(['@' + s for s in text.split()])
    return _symbols_to_sequence(["@" + s for s in text.split()])


def _should_keep_symbol(s):
    return s in _symbol_to_id and s not in ['~', '^', '_']
    return s in _symbol_to_id and s not in ["~", "^", "_"]


def _should_keep_phoneme(p):
    return p in _phonemes_to_id and p not in ['~', '^', '_']
    return p in _phonemes_to_id and p not in ["~", "^", "_"]

@@ -1,66 +1,73 @@
import re

# List of (regular expression, replacement) pairs for abbreviations in english:
abbreviations_en = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                    for x in [
                        ('mrs', 'misess'),
                        ('mr', 'mister'),
                        ('dr', 'doctor'),
                        ('st', 'saint'),
                        ('co', 'company'),
                        ('jr', 'junior'),
                        ('maj', 'major'),
                        ('gen', 'general'),
                        ('drs', 'doctors'),
                        ('rev', 'reverend'),
                        ('lt', 'lieutenant'),
                        ('hon', 'honorable'),
                        ('sgt', 'sergeant'),
                        ('capt', 'captain'),
                        ('esq', 'esquire'),
                        ('ltd', 'limited'),
                        ('col', 'colonel'),
                        ('ft', 'fort'),
                    ]]
abbreviations_en = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("mrs", "misess"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]

# List of (regular expression, replacement) pairs for abbreviations in french:
abbreviations_fr = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
                    for x in [
                        ('M', 'monsieur'),
                        ('Mlle', 'mademoiselle'),
                        ('Mlles', 'mesdemoiselles'),
                        ('Mme', 'Madame'),
                        ('Mmes', 'Mesdames'),
                        ('N.B', 'nota bene'),
                        ('M', 'monsieur'),
                        ('p.c.q', 'parce que'),
                        ('Pr', 'professeur'),
                        ('qqch', 'quelque chose'),
                        ('rdv', 'rendez-vous'),
                        ('max', 'maximum'),
                        ('min', 'minimum'),
                        ('no', 'numéro'),
                        ('adr', 'adresse'),
                        ('dr', 'docteur'),
                        ('st', 'saint'),
                        ('co', 'companie'),
                        ('jr', 'junior'),
                        ('sgt', 'sergent'),
                        ('capt', 'capitain'),
                        ('col', 'colonel'),
                        ('av', 'avenue'),
                        ('av. J.-C', 'avant Jésus-Christ'),
                        ('apr. J.-C', 'après Jésus-Christ'),
                        ('art', 'article'),
                        ('boul', 'boulevard'),
                        ('c.-à-d', 'c’est-à-dire'),
                        ('etc', 'et cetera'),
                        ('ex', 'exemple'),
                        ('excl', 'exclusivement'),
                        ('boul', 'boulevard'),
                    ]] + [(re.compile('\\b%s' % x[0]), x[1]) for x in [
                        ('Mlle', 'mademoiselle'),
                        ('Mlles', 'mesdemoiselles'),
                        ('Mme', 'Madame'),
                        ('Mmes', 'Mesdames'),
                    ]]
abbreviations_fr = [
    (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
    for x in [
        ("M", "monsieur"),
        ("Mlle", "mademoiselle"),
        ("Mlles", "mesdemoiselles"),
        ("Mme", "Madame"),
        ("Mmes", "Mesdames"),
        ("N.B", "nota bene"),
        ("M", "monsieur"),
        ("p.c.q", "parce que"),
        ("Pr", "professeur"),
        ("qqch", "quelque chose"),
        ("rdv", "rendez-vous"),
        ("max", "maximum"),
        ("min", "minimum"),
        ("no", "numéro"),
        ("adr", "adresse"),
        ("dr", "docteur"),
        ("st", "saint"),
        ("co", "companie"),
        ("jr", "junior"),
        ("sgt", "sergent"),
        ("capt", "capitain"),
        ("col", "colonel"),
        ("av", "avenue"),
        ("av. J.-C", "avant Jésus-Christ"),
        ("apr. J.-C", "après Jésus-Christ"),
        ("art", "article"),
        ("boul", "boulevard"),
        ("c.-à-d", "c’est-à-dire"),
        ("etc", "et cetera"),
        ("ex", "exemple"),
        ("excl", "exclusivement"),
        ("boul", "boulevard"),
    ]
] + [
    (re.compile("\\b%s" % x[0]), x[1])
    for x in [
        ("Mlle", "mademoiselle"),
        ("Mlles", "mesdemoiselles"),
        ("Mme", "Madame"),
        ("Mmes", "Mesdames"),
    ]
]

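[Editor's note] A concrete check (not in the commit) of how these pairs are applied by the cleaners' expand_abbreviations():

    import re

    text = "Dr. Smith lives near St. James ft."
    for regex, replacement in abbreviations_en:
        text = re.sub(regex, replacement, text)
    # -> "doctor Smith lives near saint James fort"
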
@@ -1,4 +1,4 @@
'''
"""
Cleaners are transformations that run over the input text at both training and eval time.

Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"

@@ -8,7 +8,7 @@ hyperparameter. Some cleaners are English-specific. You'll typically want to use
   the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
   the symbols in symbols.py to match your data).
'''
"""

import re
from unidecode import unidecode

@@ -19,13 +19,13 @@ from TTS.tts.utils.chinese_mandarin.numbers import replace_numbers_to_characters


# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
_whitespace_re = re.compile(r"\s+")


def expand_abbreviations(text, lang='en'):
    if lang == 'en':
def expand_abbreviations(text, lang="en"):
    if lang == "en":
        _abbreviations = abbreviations_en
    elif lang == 'fr':
    elif lang == "fr":
        _abbreviations = abbreviations_fr
    for regex, replacement in _abbreviations:
        text = re.sub(regex, replacement, text)

@@ -41,7 +41,7 @@ def lowercase(text):


def collapse_whitespace(text):
    return re.sub(_whitespace_re, ' ', text).strip()
    return re.sub(_whitespace_re, " ", text).strip()


def convert_to_ascii(text):

@@ -49,30 +49,32 @@ def convert_to_ascii(text):


def remove_aux_symbols(text):
    text = re.sub(r'[\<\>\(\)\[\]\"]+', '', text)
    text = re.sub(r"[\<\>\(\)\[\]\"]+", "", text)
    return text

def replace_symbols(text, lang='en'):
    text = text.replace(';', ',')
    text = text.replace('-', ' ')
    text = text.replace(':', ',')
    if lang == 'en':
        text = text.replace('&', ' and ')
    elif lang == 'fr':
        text = text.replace('&', ' et ')
    elif lang == 'pt':
        text = text.replace('&', ' e ')

def replace_symbols(text, lang="en"):
    text = text.replace(";", ",")
    text = text.replace("-", " ")
    text = text.replace(":", ",")
    if lang == "en":
        text = text.replace("&", " and ")
    elif lang == "fr":
        text = text.replace("&", " et ")
    elif lang == "pt":
        text = text.replace("&", " e ")
    return text


def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text


def transliteration_cleaners(text):
    '''Pipeline for non-English text that transliterates to ASCII.'''
    """Pipeline for non-English text that transliterates to ASCII."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = collapse_whitespace(text)

@@ -80,7 +82,7 @@ def transliteration_cleaners(text):


def basic_german_cleaners(text):
    '''Pipeline for German text'''
    """Pipeline for German text"""
    text = lowercase(text)
    text = collapse_whitespace(text)
    return text

@@ -88,7 +90,7 @@ def basic_german_cleaners(text):

# TODO: elaborate it
def basic_turkish_cleaners(text):
    '''Pipeline for Turkish text'''
    """Pipeline for Turkish text"""
    text = text.replace("I", "ı")
    text = lowercase(text)
    text = collapse_whitespace(text)

@@ -96,7 +98,7 @@ def basic_turkish_cleaners(text):


def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    """Pipeline for English text, including number and abbreviation expansion."""
    text = convert_to_ascii(text)
    text = lowercase(text)
    text = expand_time_english(text)

@@ -109,33 +111,33 @@ def english_cleaners(text):


def french_cleaners(text):
    '''Pipeline for French text. There is no need to expand numbers, phonemizer already does that'''
    text = expand_abbreviations(text, lang='fr')
    """Pipeline for French text. There is no need to expand numbers, phonemizer already does that"""
    text = expand_abbreviations(text, lang="fr")
    text = lowercase(text)
    text = replace_symbols(text, lang='fr')
    text = replace_symbols(text, lang="fr")
    text = remove_aux_symbols(text)
    text = collapse_whitespace(text)
    return text


def portuguese_cleaners(text):
    '''Basic pipeline for Portuguese text. There is no need to expand abbreviation and
    numbers, phonemizer already does that'''
    """Basic pipeline for Portuguese text. There is no need to expand abbreviation and
    numbers, phonemizer already does that"""
    text = lowercase(text)
    text = replace_symbols(text, lang='pt')
    text = replace_symbols(text, lang="pt")
    text = remove_aux_symbols(text)
    text = collapse_whitespace(text)
    return text


def chinese_mandarin_cleaners(text: str) -> str:
    '''Basic pipeline for chinese'''
    """Basic pipeline for chinese"""
    text = replace_numbers_to_characters_in_text(text)
    return text


def phoneme_cleaners(text):
    '''Pipeline for phonemes mode, including number and abbreviation expansion.'''
    """Pipeline for phonemes mode, including number and abbreviation expansion."""
    text = expand_numbers(text)
    text = convert_to_ascii(text)
    text = expand_abbreviations(text)

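[Editor's note] A quick check (not in the commit) of the basic pipeline above:

    assert basic_cleaners("  Hello,   WORLD! ") == "hello, world!"
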
@@ -3,43 +3,116 @@
import re

VALID_SYMBOLS = [
    'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1',
    'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0',
    'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0',
    'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0',
    'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG',
    'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W',
    'Y', 'Z', 'ZH'
    "AA",
    "AA0",
    "AA1",
    "AA2",
    "AE",
    "AE0",
    "AE1",
    "AE2",
    "AH",
    "AH0",
    "AH1",
    "AH2",
    "AO",
    "AO0",
    "AO1",
    "AO2",
    "AW",
    "AW0",
    "AW1",
    "AW2",
    "AY",
    "AY0",
    "AY1",
    "AY2",
    "B",
    "CH",
    "D",
    "DH",
    "EH",
    "EH0",
    "EH1",
    "EH2",
    "ER",
    "ER0",
    "ER1",
    "ER2",
    "EY",
    "EY0",
    "EY1",
    "EY2",
    "F",
    "G",
    "HH",
    "IH",
    "IH0",
    "IH1",
    "IH2",
    "IY",
    "IY0",
    "IY1",
    "IY2",
    "JH",
    "K",
    "L",
    "M",
    "N",
    "NG",
    "OW",
    "OW0",
    "OW1",
    "OW2",
    "OY",
    "OY0",
    "OY1",
    "OY2",
    "P",
    "R",
    "S",
    "SH",
    "T",
    "TH",
    "UH",
    "UH0",
    "UH1",
    "UH2",
    "UW",
    "UW0",
    "UW1",
    "UW2",
    "V",
    "W",
    "Y",
    "Z",
    "ZH",
]


class CMUDict:
    '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict'''
    """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""

    def __init__(self, file_or_path, keep_ambiguous=True):
        if isinstance(file_or_path, str):
            with open(file_or_path, encoding='latin-1') as f:
            with open(file_or_path, encoding="latin-1") as f:
                entries = _parse_cmudict(f)
        else:
            entries = _parse_cmudict(file_or_path)
        if not keep_ambiguous:
            entries = {
                word: pron
                for word, pron in entries.items() if len(pron) == 1
            }
            entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
        self._entries = entries

    def __len__(self):
        return len(self._entries)

    def lookup(self, word):
        '''Returns list of ARPAbet pronunciations of the given word.'''
        """Returns list of ARPAbet pronunciations of the given word."""
        return self._entries.get(word.upper())

    @staticmethod
    def get_arpabet(word, cmudict, punctuation_symbols):
        first_symbol, last_symbol = '', ''
        first_symbol, last_symbol = "", ""
        if word and word[0] in punctuation_symbols:
            first_symbol = word[0]
            word = word[1:]

@@ -48,19 +121,19 @@ class CMUDict:
            word = word[:-1]
        arpabet = cmudict.lookup(word)
        if arpabet is not None:
            return first_symbol + '{%s}' % arpabet[0] + last_symbol
            return first_symbol + "{%s}" % arpabet[0] + last_symbol
        return first_symbol + word + last_symbol


_alt_re = re.compile(r'\([0-9]+\)')
_alt_re = re.compile(r"\([0-9]+\)")


def _parse_cmudict(file):
    cmudict = {}
    for line in file:
        if line and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"):
            parts = line.split(' ')
            word = re.sub(_alt_re, '', parts[0])
        if line and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
            parts = line.split(" ")
            word = re.sub(_alt_re, "", parts[0])
            pronunciation = _get_pronunciation(parts[1])
            if pronunciation:
                if word in cmudict:

@@ -71,8 +144,8 @@ def _parse_cmudict(file):


def _get_pronunciation(s):
    parts = s.strip().split(' ')
    parts = s.strip().split(" ")
    for part in parts:
        if part not in VALID_SYMBOLS:
            return None
    return ' '.join(parts)
    return " ".join(parts)

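[Editor's note] Illustrative usage (not in the commit); the dictionary file path is hypothetical and the entries depend on the CMUdict release used:

    cmudict = CMUDict("cmudict-0.7b")  # hypothetical local copy of the CMU dictionary
    print(cmudict.lookup("HELLO"))     # e.g. ['HH AH0 L OW1']
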
@@ -5,23 +5,23 @@ import re
from typing import Dict

_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_currency_re = re.compile(r'(£|\$|¥)([0-9\,\.]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'-?[0-9]+')
_comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
_decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
_currency_re = re.compile(r"(£|\$|¥)([0-9\,\.]*[0-9]+)")
_ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
_number_re = re.compile(r"-?[0-9]+")


def _remove_commas(m):
    return m.group(1).replace(',', '')
    return m.group(1).replace(",", "")


def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')
    return m.group(1).replace(".", " point ")


def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
    parts = value.replace(",", "").split('.')
    parts = value.replace(",", "").split(".")
    if len(parts) > 2:
        return f"{value} {inflection[2]}"  # Unexpected format
    text = []

@@ -31,7 +31,7 @@ def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
        text.append(f"{integer} {integer_unit}")
    fraction = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    if fraction > 0:
        fraction_unit = inflection.get(fraction/100, inflection[0.02])
        fraction_unit = inflection.get(fraction / 100, inflection[0.02])
        text.append(f"{fraction} {fraction_unit}")
    if len(text) == 0:
        return f"zero {inflection[2]}"

@@ -62,7 +62,7 @@ def _expand_currency(m: "re.Match") -> str:
            # TODO rin
            0.02: "sen",
            2: "yen",
        }
        },
    }
    unit = m.group(1)
    currency = currencies[unit]

@@ -78,16 +78,13 @@ def _expand_number(m):
    num = int(m.group(0))
    if 1000 < num < 3000:
        if num == 2000:
            return 'two thousand'
            return "two thousand"
        if 2000 < num < 2010:
            return 'two thousand ' + _inflect.number_to_words(num % 100)
            return "two thousand " + _inflect.number_to_words(num % 100)
        if num % 100 == 0:
            return _inflect.number_to_words(num // 100) + ' hundred'
        return _inflect.number_to_words(num,
                                        andword='',
                                        zero='oh',
                                        group=2).replace(', ', ' ')
    return _inflect.number_to_words(num, andword='')
            return _inflect.number_to_words(num // 100) + " hundred"
        return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
    return _inflect.number_to_words(num, andword="")


def normalize_numbers(text):

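[Editor's note] A behaviour sketch (not in the commit) of the year-style branch above; 1999 is in (1000, 3000), is not 2000-2009, and is not a multiple of 100:

    import re

    m = re.match(r"-?[0-9]+", "1999")
    print(_expand_number(m))  # expected: "nineteen ninety-nine"
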
@@ -1,38 +1,41 @@
# -*- coding: utf-8 -*-
'''
"""
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run
through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
'''
"""


def make_symbols(characters, phonemes=None, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'):# pylint: disable=redefined-outer-name
    ''' Function to create symbols and phonemes '''
def make_symbols(
    characters, phonemes=None, punctuations="!'(),-.:;? ", pad="_", eos="~", bos="^"
):  # pylint: disable=redefined-outer-name
    """ Function to create symbols and phonemes """
    _symbols = [pad, eos, bos] + list(characters)
    _phonemes = None
    if phonemes is not None:
        _phonemes_sorted = sorted(list(set(phonemes)))
        # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
        _arpabet = ['@' + s for s in _phonemes_sorted]
        _arpabet = ["@" + s for s in _phonemes_sorted]
        # Export all symbols:
        _phonemes = [pad, eos, bos] + list(_phonemes_sorted) + list(punctuations)
        _symbols += _arpabet
    return _symbols, _phonemes

_pad = '_'
_eos = '~'
_bos = '^'
_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'(),-.:;? '
_punctuations = '!\'(),-.:;? '

_pad = "_"
_eos = "~"
_bos = "^"
_characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? "
_punctuations = "!'(),-.:;? "

# Phonemes definition (All IPA characters)
_vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
_suprasegmentals = 'ˈˌːˑ'
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧʲ'
_diacrilics = 'ɚ˞ɫ'
_vowels = "iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ"
_non_pulmonic_consonants = "ʘɓǀɗǃʄǂɠǁʛ"
_pulmonic_consonants = "pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ"
_suprasegmentals = "ˈˌːˑ"
_other_symbols = "ʍwɥʜʢʡɕʑɺɧʲ"
_diacrilics = "ɚ˞ɫ"
_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics

symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _eos, _bos)

@@ -43,16 +46,18 @@ symbols, phonemes = make_symbols(_characters, _phonemes, _punctuations, _pad, _e


def parse_symbols():
    return {'pad': _pad,
            'eos': _eos,
            'bos': _bos,
            'characters': _characters,
            'punctuations': _punctuations,
            'phonemes': _phonemes}
    return {
        "pad": _pad,
        "eos": _eos,
        "bos": _bos,
        "characters": _characters,
        "punctuations": _punctuations,
        "phonemes": _phonemes,
    }


if __name__ == '__main__':
if __name__ == "__main__":
    print(" > TTS symbols {}".format(len(symbols)))
    print(symbols)
    print(" > TTS phonemes {}".format(len(phonemes)))
    print(''.join(sorted(phonemes)))
    print("".join(sorted(phonemes)))

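[Editor's note] A small illustration (not in the commit) of make_symbols() from the block above:

    syms, phons = make_symbols("abc", phonemes="db")
    # syms  -> ['_', '~', '^', 'a', 'b', 'c', '@b', '@d']  (pad/eos/bos first, then '@'-prefixed phonemes)
    # phons -> ['_', '~', '^', 'b', 'd'] + the punctuation characters
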
@@ -3,13 +3,15 @@ import inflect

_inflect = inflect.engine()

_time_re = re.compile(r"""\b
_time_re = re.compile(
    r"""\b
((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3]))  # hours
:
([0-5][0-9])  # minutes
\s*(a\\.m\\.|am|pm|p\\.m\\.|a\\.m|p\\.m)?  # am/pm
\b""",
                      re.IGNORECASE | re.X)
    re.IGNORECASE | re.X,
)


def _expand_num(n: int) -> str:

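[Editor's note] A quick check (not in the commit) of _time_re above:

    m = _time_re.search("The train leaves at 3:30 pm today.")
    print(m.group(0))  # -> "3:30 pm"
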
@@ -3,33 +3,25 @@ import matplotlib
import numpy as np
import torch

matplotlib.use('Agg')
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from TTS.tts.utils.text import phoneme_to_sequence, sequence_to_phoneme


def plot_alignment(alignment,
                   info=None,
                   fig_size=(16, 10),
                   title=None,
                   output_fig=False):
def plot_alignment(alignment, info=None, fig_size=(16, 10), title=None, output_fig=False):
    if isinstance(alignment, torch.Tensor):
        alignment_ = alignment.detach().cpu().numpy().squeeze()
    else:
        alignment_ = alignment
    alignment_ = alignment_.astype(
        np.float32) if alignment_.dtype == np.float16 else alignment_
    alignment_ = alignment_.astype(np.float32) if alignment_.dtype == np.float16 else alignment_
    fig, ax = plt.subplots(figsize=fig_size)
    im = ax.imshow(alignment_.T,
                   aspect='auto',
                   origin='lower',
                   interpolation='none')
    im = ax.imshow(alignment_.T, aspect="auto", origin="lower", interpolation="none")
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += '\n\n' + info
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.ylabel("Encoder timestep")
    # plt.yticks(range(len(text)), list(text))
    plt.tight_layout()
    if title is not None:

@@ -39,16 +31,12 @@ def plot_alignment(alignment,
    return fig


def plot_spectrogram(spectrogram,
                     ap=None,
                     fig_size=(16, 10),
                     output_fig=False):
def plot_spectrogram(spectrogram, ap=None, fig_size=(16, 10), output_fig=False):
    if isinstance(spectrogram, torch.Tensor):
        spectrogram_ = spectrogram.detach().cpu().numpy().squeeze().T
    else:
        spectrogram_ = spectrogram.T
    spectrogram_ = spectrogram_.astype(
        np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
    spectrogram_ = spectrogram_.astype(np.float32) if spectrogram_.dtype == np.float16 else spectrogram_
    if ap is not None:
        spectrogram_ = ap.denormalize(spectrogram_)  # pylint: disable=protected-access
    fig = plt.figure(figsize=fig_size)

@@ -60,16 +48,18 @@ def plot_spectrogram(spectrogram,
    return fig


def visualize(alignment,
              postnet_output,
              text,
              hop_length,
              CONFIG,
              stop_tokens=None,
              decoder_output=None,
              output_path=None,
              figsize=(8, 24),
              output_fig=False):
def visualize(
    alignment,
    postnet_output,
    text,
    hop_length,
    CONFIG,
    stop_tokens=None,
    decoder_output=None,
    output_path=None,
    figsize=(8, 24),
    output_fig=False,
):

    if decoder_output is not None:
        num_plot = 4

@@ -86,13 +76,13 @@ def visualize(alignment,
        # compute phoneme representation and back
        if CONFIG.use_phonemes:
            seq = phoneme_to_sequence(
                text, [CONFIG.text_cleaner],
                text,
                [CONFIG.text_cleaner],
                CONFIG.phoneme_language,
                CONFIG.enable_eos_bos_chars,
                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
            text = sequence_to_phoneme(
                seq,
                tp=CONFIG.characters if 'characters' in CONFIG.keys() else None)
                tp=CONFIG.characters if "characters" in CONFIG.keys() else None,
            )
            text = sequence_to_phoneme(seq, tp=CONFIG.characters if "characters" in CONFIG.keys() else None)
        print(text)
        plt.yticks(range(len(text)), list(text))
    plt.colorbar()

@@ -104,13 +94,15 @@ def visualize(alignment,

    # plot postnet spectrogram
    plt.subplot(num_plot, 1, 3)
    librosa.display.specshow(postnet_output.T,
                             sr=CONFIG.audio['sample_rate'],
                             hop_length=hop_length,
                             x_axis="time",
                             y_axis="linear",
                             fmin=CONFIG.audio['mel_fmin'],
                             fmax=CONFIG.audio['mel_fmax'])
    librosa.display.specshow(
        postnet_output.T,
        sr=CONFIG.audio["sample_rate"],
        hop_length=hop_length,
        x_axis="time",
        y_axis="linear",
        fmin=CONFIG.audio["mel_fmin"],
        fmax=CONFIG.audio["mel_fmax"],
    )

    plt.xlabel("Time", fontsize=label_fontsize)
    plt.ylabel("Hz", fontsize=label_fontsize)

@@ -119,13 +111,15 @@ def visualize(alignment,

    if decoder_output is not None:
        plt.subplot(num_plot, 1, 4)
        librosa.display.specshow(decoder_output.T,
                                 sr=CONFIG.audio['sample_rate'],
                                 hop_length=hop_length,
                                 x_axis="time",
                                 y_axis="linear",
                                 fmin=CONFIG.audio['mel_fmin'],
                                 fmax=CONFIG.audio['mel_fmax'])
        librosa.display.specshow(
            decoder_output.T,
            sr=CONFIG.audio["sample_rate"],
            hop_length=hop_length,
            x_axis="time",
            y_axis="linear",
            fmin=CONFIG.audio["mel_fmin"],
            fmax=CONFIG.audio["mel_fmax"],
        )
        plt.xlabel("Time", fontsize=label_fontsize)
        plt.ylabel("Hz", fontsize=label_fontsize)
    plt.tight_layout()

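[Editor's note] Illustrative usage of plot_alignment (not in the commit):

    import numpy as np

    fig = plot_alignment(np.random.rand(120, 40), output_fig=True)  # (decoder steps, encoder steps)
    fig.savefig("alignment.png")
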
@@ -29,41 +29,31 @@ def parse_arguments(argv):
    parser.add_argument(
        "--continue_path",
        type=str,
        help=("Training output folder to continue training. Used to continue "
              "a training. If it is used, 'config_path' is ignored."),
        help=(
            "Training output folder to continue training. Used to continue "
            "a training. If it is used, 'config_path' is ignored."
        ),
        default="",
        required="--config_path" not in argv)
        required="--config_path" not in argv,
    )
    parser.add_argument(
        "--restore_path",
        type=str,
        help="Model file to be restored. Use to finetune a model.",
        default="")
        "--restore_path", type=str, help="Model file to be restored. Use to finetune a model.", default=""
    )
    parser.add_argument(
        "--best_path",
        type=str,
        help=("Best model file to be used for extracting best loss."
              "If not specified, the latest best model in continue path is used"),
        default="")
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to config file for training.",
        required="--continue_path" not in argv)
    parser.add_argument(
        "--debug",
        type=bool,
        default=False,
        help="Do not verify commit integrity to run training.")
    parser.add_argument(
        "--rank",
        type=int,
        default=0,
        help="DISTRIBUTED: process rank for distributed training.")
    parser.add_argument(
        "--group_id",
        type=str,
        help=(
            "Best model file to be used for extracting best loss."
            "If not specified, the latest best model in continue path is used"
        ),
        default="",
        help="DISTRIBUTED: process group id.")
    )
    parser.add_argument(
        "--config_path", type=str, help="Path to config file for training.", required="--continue_path" not in argv
    )
    parser.add_argument("--debug", type=bool, default=False, help="Do not verify commit integrity to run training.")
    parser.add_argument("--rank", type=int, default=0, help="DISTRIBUTED: process rank for distributed training.")
    parser.add_argument("--group_id", type=str, default="", help="DISTRIBUTED: process group id.")

    return parser.parse_args()


@@ -86,7 +76,7 @@ def get_last_checkpoint(path):
    file_names = glob.glob(os.path.join(path, "*.pth.tar"))
    last_models = {}
    last_model_nums = {}
    for key in ['checkpoint', 'best_model']:
    for key in ["checkpoint", "best_model"]:
        last_model_num = None
        last_model = None
        # pass all the checkpoint files and find

@@ -105,7 +95,7 @@ def get_last_checkpoint(path):
        key_file_names = [fn for fn in file_names if key in fn]
        if last_model is None and len(key_file_names) > 0:
            last_model = max(key_file_names, key=os.path.getctime)
            last_model_num = torch.load(last_model)['step']
            last_model_num = torch.load(last_model)["step"]

        if last_model is not None:
            last_models[key] = last_model

@@ -114,16 +104,16 @@ def get_last_checkpoint(path):
    # check what models were found
    if not last_models:
        raise ValueError(f"No models found in continue path {path}!")
    if 'checkpoint' not in last_models:  # no checkpoint just best model
        last_models['checkpoint'] = last_models['best_model']
    elif 'best_model' not in last_models:  # no best model
    if "checkpoint" not in last_models:  # no checkpoint just best model
        last_models["checkpoint"] = last_models["best_model"]
    elif "best_model" not in last_models:  # no best model
        # this shouldn't happen, but let's handle it just in case
        last_models['best_model'] = None
        last_models["best_model"] = None
    # finally check if last best model is more recent than checkpoint
    elif last_model_nums['best_model'] > last_model_nums['checkpoint']:
        last_models['checkpoint'] = last_models['best_model']
    elif last_model_nums["best_model"] > last_model_nums["checkpoint"]:
        last_models["checkpoint"] = last_models["best_model"]

    return last_models['checkpoint'], last_models['best_model']
    return last_models["checkpoint"], last_models["best_model"]


def process_args(args, model_class):

@@ -157,13 +147,12 @@ def process_args(args, model_class):
    c = load_config(args.config_path)
    _ = os.path.dirname(os.path.realpath(__file__))

    if 'mixed_precision' in c and c.mixed_precision:
    if "mixed_precision" in c and c.mixed_precision:
        print(" > Mixed precision mode is ON")

    out_path = args.continue_path
    if not out_path:
        out_path = create_experiment_folder(c.output_path, c.run_name,
                                            args.debug)
        out_path = create_experiment_folder(c.output_path, c.run_name, args.debug)

    audio_path = os.path.join(out_path, "test_audios")


@@ -179,11 +168,10 @@ def process_args(args, model_class):
    # if model characters are not set in the config file
    # save the default set to the config file for future
    # compatibility.
    if model_class == 'tts' and 'characters' not in c:
    if model_class == "tts" and "characters" not in c:
        used_characters = parse_symbols()
        new_fields['characters'] = used_characters
        copy_model_files(c, args.config_path,
                         out_path, new_fields)
        new_fields["characters"] = used_characters
        copy_model_files(c, args.config_path, out_path, new_fields)
    os.chmod(audio_path, 0o775)
    os.chmod(out_path, 0o775)

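[Editor's note] An illustrative invocation (not in the commit) of the argument scheme above; the training script name and paths are hypothetical:

    # python TTS/bin/train_tacotron.py --config_path config.json \
    #        --restore_path /path/to/checkpoint_100000.pth.tar
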
@ -3,11 +3,12 @@ import soundfile as sf
|
|||
import numpy as np
|
||||
import scipy.io.wavfile
|
||||
import scipy.signal
|
||||
|
||||
# import pyworld as pw
|
||||
|
||||
from TTS.tts.utils.data import StandardScaler
|
||||
|
||||
#pylint: disable=too-many-public-methods
|
||||
# pylint: disable=too-many-public-methods
|
||||
class AudioProcessor(object):
|
||||
"""Audio Processor for TTS used by all the data pipelines.
|
||||
|
||||
|
@ -43,35 +44,38 @@ class AudioProcessor(object):
|
|||
stats_path (str, optional): Path to the computed stats file. Defaults to None.
|
||||
verbose (bool, optional): enable/disable logging. Defaults to True.
|
||||
"""
|
||||
def __init__(self,
|
||||
sample_rate=None,
|
||||
resample=False,
|
||||
num_mels=None,
|
||||
log_func='np.log10',
|
||||
min_level_db=None,
|
||||
frame_shift_ms=None,
|
||||
frame_length_ms=None,
|
||||
hop_length=None,
|
||||
win_length=None,
|
||||
ref_level_db=None,
|
||||
fft_size=1024,
|
||||
power=None,
|
||||
preemphasis=0.0,
|
||||
signal_norm=None,
|
||||
symmetric_norm=None,
|
||||
max_norm=None,
|
||||
mel_fmin=None,
|
||||
mel_fmax=None,
|
||||
spec_gain=20,
|
||||
stft_pad_mode='reflect',
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=None,
|
||||
do_trim_silence=False,
|
||||
trim_db=60,
|
||||
do_sound_norm=False,
|
||||
stats_path=None,
|
||||
verbose=True,
|
||||
**_):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate=None,
|
||||
resample=False,
|
||||
num_mels=None,
|
||||
log_func="np.log10",
|
||||
min_level_db=None,
|
||||
frame_shift_ms=None,
|
||||
frame_length_ms=None,
|
||||
hop_length=None,
|
||||
win_length=None,
|
||||
ref_level_db=None,
|
||||
fft_size=1024,
|
||||
power=None,
|
||||
preemphasis=0.0,
|
||||
signal_norm=None,
|
||||
symmetric_norm=None,
|
||||
max_norm=None,
|
||||
mel_fmin=None,
|
||||
mel_fmax=None,
|
||||
spec_gain=20,
|
||||
stft_pad_mode="reflect",
|
||||
clip_norm=True,
|
||||
griffin_lim_iters=None,
|
||||
do_trim_silence=False,
|
||||
trim_db=60,
|
||||
do_sound_norm=False,
|
||||
stats_path=None,
|
||||
verbose=True,
|
||||
**_,
|
||||
):
|
||||
|
||||
# setup class attributed
|
||||
self.sample_rate = sample_rate
|
||||
|
@ -98,14 +102,14 @@ class AudioProcessor(object):
|
|||
self.do_sound_norm = do_sound_norm
|
||||
self.stats_path = stats_path
|
||||
# setup exp_func for db to amp conversion
|
||||
print(f'self.log_func = {log_func}')
|
||||
exec(f'self.log_func = {log_func}') #pylint: disable=exec-used
|
||||
if self.log_func.__name__ == 'log':
|
||||
print(f"self.log_func = {log_func}")
|
||||
exec(f"self.log_func = {log_func}") # pylint: disable=exec-used
|
||||
if self.log_func.__name__ == "log":
|
||||
self.exp_func = np.exp
|
||||
elif self.log_func.__name__ == 'log10':
|
||||
elif self.log_func.__name__ == "log10":
|
||||
self.exp_func = lambda x: 10 ** x
|
||||
else:
|
||||
raise ValueError(' [!] unknown `log_func` value.')
|
||||
raise ValueError(" [!] unknown `log_func` value.")
|
||||
# setup stft parameters
|
||||
if hop_length is None:
|
||||
# compute stft parameters from given time values
|
||||
|
@ -134,17 +138,18 @@ class AudioProcessor(object):
|
|||
self.symmetric_norm = None
|
||||
|
||||
### setting up the parameters ###
|
||||
def _build_mel_basis(self, ):
|
||||
def _build_mel_basis(
|
||||
self,
|
||||
):
|
||||
if self.mel_fmax is not None:
|
||||
assert self.mel_fmax <= self.sample_rate // 2
|
||||
return librosa.filters.mel(
|
||||
self.sample_rate,
|
||||
self.fft_size,
|
||||
n_mels=self.num_mels,
|
||||
fmin=self.mel_fmin,
|
||||
fmax=self.mel_fmax)
|
||||
self.sample_rate, self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
|
||||
)
|
||||
|
||||
def _stft_parameters(self, ):
|
||||
def _stft_parameters(
|
||||
self,
|
||||
):
|
||||
"""Compute necessary stft parameters with given time values"""
|
||||
factor = self.frame_length_ms / self.frame_shift_ms
|
||||
assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms"
|
||||
|
@ -155,24 +160,26 @@ class AudioProcessor(object):
|
|||
### normalization ###
|
||||
def normalize(self, S):
|
||||
"""Put values in [0, self.max_norm] or [-self.max_norm, self.max_norm]"""
|
||||
#pylint: disable=no-else-return
|
||||
# pylint: disable=no-else-return
|
||||
S = S.copy()
|
||||
if self.signal_norm:
|
||||
# mean-var scaling
|
||||
if hasattr(self, 'mel_scaler'):
|
||||
if hasattr(self, "mel_scaler"):
|
||||
if S.shape[0] == self.num_mels:
|
||||
return self.mel_scaler.transform(S.T).T
|
||||
elif S.shape[0] == self.fft_size / 2:
|
||||
return self.linear_scaler.transform(S.T).T
|
||||
else:
|
||||
raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
|
||||
raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
|
||||
# range normalization
|
||||
S -= self.ref_level_db # discard certain range of DB assuming it is air noise
|
||||
S_norm = ((S - self.min_level_db) / (-self.min_level_db))
|
||||
S_norm = (S - self.min_level_db) / (-self.min_level_db)
|
||||
if self.symmetric_norm:
|
||||
S_norm = ((2 * self.max_norm) * S_norm) - self.max_norm
|
||||
if self.clip_norm:
|
||||
S_norm = np.clip(S_norm, -self.max_norm, self.max_norm) # pylint: disable=invalid-unary-operand-type
|
||||
S_norm = np.clip(
|
||||
S_norm, -self.max_norm, self.max_norm # pylint: disable=invalid-unary-operand-type
|
||||
)
|
||||
return S_norm
|
||||
else:
|
||||
S_norm = self.max_norm * S_norm
|
||||
|
@ -184,47 +191,49 @@ class AudioProcessor(object):

     def denormalize(self, S):
         """denormalize values"""
-        #pylint: disable=no-else-return
+        # pylint: disable=no-else-return
         S_denorm = S.copy()
         if self.signal_norm:
             # mean-var scaling
-            if hasattr(self, 'mel_scaler'):
+            if hasattr(self, "mel_scaler"):
                 if S_denorm.shape[0] == self.num_mels:
                     return self.mel_scaler.inverse_transform(S_denorm.T).T
                 elif S_denorm.shape[0] == self.fft_size / 2:
                     return self.linear_scaler.inverse_transform(S_denorm.T).T
                 else:
-                    raise RuntimeError(' [!] Mean-Var stats does not match the given feature dimensions.')
+                    raise RuntimeError(" [!] Mean-Var stats does not match the given feature dimensions.")
             if self.symmetric_norm:
                 if self.clip_norm:
-                    S_denorm = np.clip(S_denorm, -self.max_norm, self.max_norm) #pylint: disable=invalid-unary-operand-type
+                    S_denorm = np.clip(
+                        S_denorm, -self.max_norm, self.max_norm  # pylint: disable=invalid-unary-operand-type
+                    )
                 S_denorm = ((S_denorm + self.max_norm) * -self.min_level_db / (2 * self.max_norm)) + self.min_level_db
                 return S_denorm + self.ref_level_db
             else:
                 if self.clip_norm:
                     S_denorm = np.clip(S_denorm, 0, self.max_norm)
-                S_denorm = (S_denorm * -self.min_level_db /
-                            self.max_norm) + self.min_level_db
+                S_denorm = (S_denorm * -self.min_level_db / self.max_norm) + self.min_level_db
                 return S_denorm + self.ref_level_db
         else:
             return S_denorm

     ### Mean-STD scaling ###
     def load_stats(self, stats_path):
-        stats = np.load(stats_path, allow_pickle=True).item() #pylint: disable=unexpected-keyword-arg
-        mel_mean = stats['mel_mean']
-        mel_std = stats['mel_std']
-        linear_mean = stats['linear_mean']
-        linear_std = stats['linear_std']
-        stats_config = stats['audio_config']
+        stats = np.load(stats_path, allow_pickle=True).item()  # pylint: disable=unexpected-keyword-arg
+        mel_mean = stats["mel_mean"]
+        mel_std = stats["mel_std"]
+        linear_mean = stats["linear_mean"]
+        linear_std = stats["linear_std"]
+        stats_config = stats["audio_config"]
         # check all audio parameters used for computing stats
-        skip_parameters = ['griffin_lim_iters', 'stats_path', 'do_trim_silence', 'ref_level_db', 'power']
+        skip_parameters = ["griffin_lim_iters", "stats_path", "do_trim_silence", "ref_level_db", "power"]
         for key in stats_config.keys():
             if key in skip_parameters:
                 continue
-            if key not in ['sample_rate', 'trim_db']:
-                assert stats_config[key] == self.__dict__[key],\
-                    f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
+            if key not in ["sample_rate", "trim_db"]:
+                assert (
+                    stats_config[key] == self.__dict__[key]
+                ), f" [!] Audio param {key} does not match the value used for computing mean-var stats. {stats_config[key]} vs {self.__dict__[key]}"
         return mel_mean, mel_std, linear_mean, linear_std, stats_config

     # pylint: disable=attribute-defined-outside-init
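When scale stats are available, normalize/denormalize route through the mel/linear scaler objects instead of the range arithmetic. A sketch of the equivalent per-feature mean-var scaling in plain NumPy (a stand-in for the scalers, which is an assumption about their behavior):

import numpy as np

def meanvar_normalize(S, mean, std):
    # feature-wise standardization over time frames; S is (n_mels, T)
    return ((S.T - mean) / std).T

def meanvar_denormalize(S_norm, mean, std):
    return (S_norm.T * std + mean).T

mean, std = np.array([1.0, -2.0]), np.array([0.5, 2.0])
S = np.random.randn(2, 5)
assert np.allclose(meanvar_denormalize(meanvar_normalize(S, mean, std), mean, std), S)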
@ -243,7 +252,6 @@ class AudioProcessor(object):
     def _db_to_amp(self, x):
         return self.exp_func(x / self.spec_gain)

     ### Preemphasis ###
     def apply_preemphasis(self, x):
         if self.preemphasis == 0:

@ -284,17 +292,17 @@ class AudioProcessor(object):
         S = self._db_to_amp(S)
         # Reconstruct phase
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)

     def inv_melspectrogram(self, mel_spectrogram):
-        '''Converts melspectrogram to waveform using librosa'''
+        """Converts melspectrogram to waveform using librosa"""
         D = self.denormalize(mel_spectrogram)
         S = self._db_to_amp(D)
         S = self._mel_to_linear(S)  # Convert back to linear
         if self.preemphasis != 0:
-            return self.apply_inv_preemphasis(self._griffin_lim(S**self.power))
-        return self._griffin_lim(S**self.power)
+            return self.apply_inv_preemphasis(self._griffin_lim(S ** self.power))
+        return self._griffin_lim(S ** self.power)

     def out_linear_to_mel(self, linear_spec):
         S = self.denormalize(linear_spec)

@ -315,8 +323,7 @@ class AudioProcessor(object):
         )

     def _istft(self, y):
-        return librosa.istft(
-            y, hop_length=self.hop_length, win_length=self.win_length)
+        return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length)

     def _griffin_lim(self, S):
         angles = np.exp(2j * np.pi * np.random.rand(*S.shape))

@ -328,8 +335,7 @@ class AudioProcessor(object):
         return y
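The body of _griffin_lim is elided by the hunk above; only the random phase initialization is visible. A standard Griffin-Lim reconstruction consistent with that first line might look like the following sketch (n_iter and the hop/win values are placeholders, not the project's settings):

import numpy as np
import librosa

def griffin_lim_sketch(S, n_iter=60, hop_length=256, win_length=1024):
    # S: magnitude spectrogram of shape (1 + n_fft // 2, T)
    n_fft = (S.shape[0] - 1) * 2
    angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
    y = librosa.istft(S * angles, hop_length=hop_length, win_length=win_length)
    for _ in range(n_iter):
        # re-estimate the phase from the current waveform, keep the target magnitude
        D = librosa.stft(y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        angles = np.exp(1j * np.angle(D))
        y = librosa.istft(S * angles, hop_length=hop_length, win_length=win_length)
    return y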
     def compute_stft_paddings(self, x, pad_sides=1):
-        '''compute right padding (final frame) or both sides padding (first and final frames)
-        '''
+        """compute right padding (final frame) or both sides padding (first and final frames)"""
         assert pad_sides in (1, 2)
         pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0]
         if pad_sides == 1:

@ -354,7 +360,7 @@ class AudioProcessor(object):
         hop_length = int(window_length / 4)
         threshold = self._db_to_amp(threshold_db)
         for x in range(hop_length, len(wav) - window_length, hop_length):
-            if np.max(wav[x:x + window_length]) < threshold:
+            if np.max(wav[x : x + window_length]) < threshold:
                 return x + hop_length
         return len(wav)

@ -362,8 +368,9 @@ class AudioProcessor(object):
         """ Trim silent parts with a threshold and 0.01 sec margin """
         margin = int(self.sample_rate * 0.01)
         wav = wav[margin:-margin]
-        return librosa.effects.trim(
-            wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[0]
+        return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[
+            0
+        ]

     @staticmethod
     def sound_norm(x):

@ -375,14 +382,14 @@ class AudioProcessor(object):
             x, sr = librosa.load(filename, sr=self.sample_rate)
         elif sr is None:
             x, sr = sf.read(filename)
-            assert self.sample_rate == sr, "%s vs %s"%(self.sample_rate, sr)
+            assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr)
         else:
             x, sr = librosa.load(filename, sr=sr)
         if self.do_trim_silence:
             try:
                 x = self.trim_silence(x)
             except ValueError:
-                print(f' [!] File cannot be trimmed for silence - {filename}')
+                print(f" [!] File cannot be trimmed for silence - {filename}")
         if self.do_sound_norm:
             x = self.sound_norm(x)
         return x
@ -396,10 +403,12 @@ class AudioProcessor(object):
     def mulaw_encode(wav, qc):
         mu = 2 ** qc - 1
         # wav_abs = np.minimum(np.abs(wav), 1.0)
-        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1. + mu)
+        signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
         # Quantize signal to the specified number of levels.
         signal = (signal + 1) / 2 * mu + 0.5
-        return np.floor(signal,)
+        return np.floor(
+            signal,
+        )

     @staticmethod
     def mulaw_decode(wav, qc):

@ -408,15 +417,14 @@ class AudioProcessor(object):
         x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1)
         return x

     @staticmethod
     def encode_16bits(x):
-        return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
+        return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)

     @staticmethod
     def quantize(x, bits):
-        return (x + 1.) * (2**bits - 1) / 2
+        return (x + 1.0) * (2 ** bits - 1) / 2

     @staticmethod
     def dequantize(x, bits):
-        return 2 * x / (2**bits - 1) - 1
+        return 2 * x / (2 ** bits - 1) - 1
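A quick round-trip check of the mu-law companding above, written as free functions. Note that mulaw_decode in the class expects input already rescaled to [-1, 1], so the sketch adds that rescaling step explicitly; qc=10 is an illustrative bit depth:

import numpy as np

def mulaw_encode(wav, qc):
    mu = 2 ** qc - 1
    signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu)
    return np.floor((signal + 1) / 2 * mu + 0.5)

def mulaw_decode(levels, qc):
    mu = 2 ** qc - 1
    y = 2 * levels / mu - 1  # integer levels back to [-1, 1] before expanding
    return np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)

x = np.linspace(-0.9, 0.9, 7)
x_hat = mulaw_decode(mulaw_encode(x, qc=10), qc=10)
assert np.max(np.abs(x - x_hat)) < 1e-2  # small companding + quantization error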
@ -2,19 +2,21 @@ import datetime
 from TTS.utils.io import AttrDict


-tcolors = AttrDict({
-    'OKBLUE': '\033[94m',
-    'HEADER': '\033[95m',
-    'OKGREEN': '\033[92m',
-    'WARNING': '\033[93m',
-    'FAIL': '\033[91m',
-    'ENDC': '\033[0m',
-    'BOLD': '\033[1m',
-    'UNDERLINE': '\033[4m'
-})
+tcolors = AttrDict(
+    {
+        "OKBLUE": "\033[94m",
+        "HEADER": "\033[95m",
+        "OKGREEN": "\033[92m",
+        "WARNING": "\033[93m",
+        "FAIL": "\033[91m",
+        "ENDC": "\033[0m",
+        "BOLD": "\033[1m",
+        "UNDERLINE": "\033[4m",
+    }
+)


-class ConsoleLogger():
+class ConsoleLogger:
     def __init__(self):
         # TODO: color code for value changes
         # use these to compare values between iterations

@ -28,23 +30,24 @@ class ConsoleLogger():
         return now.strftime("%Y-%m-%d %H:%M:%S")

     def print_epoch_start(self, epoch, max_epoch):
-        print("\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD,
-                                               epoch, max_epoch, tcolors.ENDC),
-              flush=True)
+        print(
+            "\n{}{} > EPOCH: {}/{}{}".format(tcolors.UNDERLINE, tcolors.BOLD, epoch, max_epoch, tcolors.ENDC),
+            flush=True,
+        )

     def print_train_start(self):
         print(f"\n{tcolors.BOLD} > TRAINING ({self.get_time()}) {tcolors.ENDC}")

-    def print_train_step(self, batch_steps, step, global_step, log_dict,
-                         loss_dict, avg_loss_dict):
+    def print_train_step(self, batch_steps, step, global_step, log_dict, loss_dict, avg_loss_dict):
         indent = " | > "
         print()
         log_text = "{} --> STEP: {}/{} -- GLOBAL_STEP: {}{}\n".format(
-            tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC)
+            tcolors.BOLD, step, batch_steps, global_step, tcolors.ENDC
+        )
         for key, value in loss_dict.items():
             # print the avg value if given
-            if f'avg_{key}' in avg_loss_dict.keys():
-                log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
+            if f"avg_{key}" in avg_loss_dict.keys():
+                log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
             else:
                 log_text += "{}{}: {:.5f} \n".format(indent, key, value)
         for idx, (key, value) in enumerate(log_dict.items()):

@ -52,13 +55,12 @@ class ConsoleLogger():
                 log_text += f"{indent}{key}: {value[0]:.{value[1]}f}"
             else:
                 log_text += f"{indent}{key}: {value}"
-            if idx < len(log_dict)-1:
+            if idx < len(log_dict) - 1:
                 log_text += "\n"
         print(log_text, flush=True)

     # pylint: disable=unused-argument
-    def print_train_epoch_end(self, global_step, epoch, epoch_time,
-                              print_dict):
+    def print_train_epoch_end(self, global_step, epoch, epoch_time, print_dict):
         indent = " | > "
         log_text = f"\n{tcolors.BOLD} --> TRAIN PERFORMACE -- EPOCH TIME: {epoch_time:.2f} sec -- GLOBAL_STEP: {global_step}{tcolors.ENDC}\n"
         for key, value in print_dict.items():

@ -74,29 +76,28 @@ class ConsoleLogger():
         log_text = f"{tcolors.BOLD} --> STEP: {step}{tcolors.ENDC}\n"
         for key, value in loss_dict.items():
             # print the avg value if given
-            if f'avg_{key}' in avg_loss_dict.keys():
-                log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f'avg_{key}'])
+            if f"avg_{key}" in avg_loss_dict.keys():
+                log_text += "{}{}: {:.5f} ({:.5f})\n".format(indent, key, value, avg_loss_dict[f"avg_{key}"])
             else:
                 log_text += "{}{}: {:.5f} \n".format(indent, key, value)
         print(log_text, flush=True)

     def print_epoch_end(self, epoch, avg_loss_dict):
         indent = " | > "
-        log_text = " {}--> EVAL PERFORMANCE{}\n".format(
-            tcolors.BOLD, tcolors.ENDC)
+        log_text = " {}--> EVAL PERFORMANCE{}\n".format(tcolors.BOLD, tcolors.ENDC)
         for key, value in avg_loss_dict.items():
             # print the avg value if given
-            color = ''
-            sign = '+'
+            color = ""
+            sign = "+"
             diff = 0
             if self.old_eval_loss_dict is not None and key in self.old_eval_loss_dict:
                 diff = value - self.old_eval_loss_dict[key]
                 if diff < 0:
                     color = tcolors.OKGREEN
-                    sign = ''
+                    sign = ""
                 elif diff > 0:
                     color = tcolors.FAIL
-                    sign = '+'
+                    sign = "+"
             log_text += "{}{}:{} {:.5f} {}({}{:.5f})\n".format(indent, key, color, value, tcolors.ENDC, sign, diff)
         self.old_eval_loss_dict = avg_loss_dict
         print(log_text, flush=True)
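The signatures above imply a call pattern like the following sketch; the loss values and precision tuple are invented for illustration (log_dict values may be (value, precision) pairs per the formatting branch above):

logger = ConsoleLogger()
logger.print_epoch_start(epoch=1, max_epoch=100)
logger.print_train_start()
logger.print_train_step(
    batch_steps=50, step=10, global_step=1010,
    log_dict={"step_time": (0.31, 2)},   # value with print precision
    loss_dict={"loss": 0.4213},
    avg_loss_dict={"avg_loss": 0.4401},
)
logger.print_epoch_end(epoch=1, avg_loss_dict={"avg_loss": 0.4401})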
@ -14,7 +14,7 @@ class DistributedSampler(Sampler):
     """

     def __init__(self, dataset, num_replicas=None, rank=None):
-        super(DistributedSampler, self).__init__(dataset)
+        super().__init__(dataset)
         if num_replicas is None:
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")

@ -34,11 +34,11 @@ class DistributedSampler(Sampler):
         indices = torch.arange(len(self.dataset)).tolist()

         # add extra samples to make it evenly divisible
-        indices += indices[:(self.total_size - len(indices))]
+        indices += indices[: (self.total_size - len(indices))]
         assert len(indices) == self.total_size

         # subsample
-        indices = indices[self.rank:self.total_size:self.num_replicas]
+        indices = indices[self.rank : self.total_size : self.num_replicas]
         assert len(indices) == self.num_samples

         return iter(indices)

@ -64,12 +64,7 @@ def init_distributed(rank, num_gpus, group_name, dist_backend, dist_url):
     torch.cuda.set_device(rank % torch.cuda.device_count())

     # Initialize distributed communication
-    dist.init_process_group(
-        dist_backend,
-        init_method=dist_url,
-        world_size=num_gpus,
-        rank=rank,
-        group_name=group_name)
+    dist.init_process_group(dist_backend, init_method=dist_url, world_size=num_gpus, rank=rank, group_name=group_name)


 def apply_gradient_allreduce(module):

@ -97,14 +92,13 @@ def apply_gradient_allreduce(module):
         coalesced = _flatten_dense_tensors(grads)
         dist.all_reduce(coalesced, op=dist.reduce_op.SUM)
         coalesced /= dist.get_world_size()
-        for buf, synced in zip(
-                grads, _unflatten_dense_tensors(coalesced, grads)):
+        for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
             buf.copy_(synced)

     for param in list(module.parameters()):

         def allreduce_hook(*_):
-            Variable._execution_engine.queue_callback(allreduce_params) #pylint: disable=protected-access
+            Variable._execution_engine.queue_callback(allreduce_params)  # pylint: disable=protected-access

         if param.requires_grad:
             param.register_hook(allreduce_hook)
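The sampler's padding and strided slicing hand each replica an evenly sized, disjoint subset; traced with small numbers:

dataset_len, num_replicas = 10, 4
num_samples = -(-dataset_len // num_replicas)    # ceil division -> 3
total_size = num_samples * num_replicas          # 12
indices = list(range(dataset_len))
indices += indices[: total_size - len(indices)]  # wrap-around padding -> 12 items
for rank in range(num_replicas):
    print(rank, indices[rank:total_size:num_replicas])
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]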
@ -10,8 +10,7 @@ from pathlib import Path
 def get_git_branch():
     try:
         out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split("\n")
-                       if line.startswith("*"))
+        current = next(line for line in out.split("\n") if line.startswith("*"))
         current.replace("* ", "")
     except subprocess.CalledProcessError:
         current = "inside_docker"

@ -29,12 +28,11 @@ def get_commit_hash():
     # raise RuntimeError(
     #     " !! Commit before training to get the commit hash.")
     try:
-        commit = subprocess.check_output(
-            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
+        commit = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode().strip()
     # Not copying .git folder into docker container
     except (subprocess.CalledProcessError, FileNotFoundError):
         commit = "0000000"
-    print(' > Git Hash: {}'.format(commit))
+    print(" > Git Hash: {}".format(commit))
     return commit


@ -42,11 +40,10 @@ def create_experiment_folder(root_path, model_name, debug):
     """ Create a folder with the current date and time """
     date_str = datetime.datetime.now().strftime("%B-%d-%Y_%I+%M%p")
     if debug:
-        commit_hash = 'debug'
+        commit_hash = "debug"
     else:
         commit_hash = get_commit_hash()
-    output_folder = os.path.join(
-        root_path, model_name + '-' + date_str + '-' + commit_hash)
+    output_folder = os.path.join(root_path, model_name + "-" + date_str + "-" + commit_hash)
     os.makedirs(output_folder, exist_ok=True)
     print(" > Experiment folder: {}".format(output_folder))
     return output_folder

@ -72,16 +69,16 @@ def count_parameters(model):
 def get_user_data_dir(appname):
     if sys.platform == "win32":
         import winreg  # pylint: disable=import-outside-toplevel

         key = winreg.OpenKey(
-            winreg.HKEY_CURRENT_USER,
-            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
+            winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders"
         )
         dir_, _ = winreg.QueryValueEx(key, "Local AppData")
         ans = Path(dir_).resolve(strict=False)
-    elif sys.platform == 'darwin':
-        ans = Path('~/Library/Application Support/').expanduser()
+    elif sys.platform == "darwin":
+        ans = Path("~/Library/Application Support/").expanduser()
     else:
-        ans = Path.home().joinpath('.local/share')
+        ans = Path.home().joinpath(".local/share")
     return ans.joinpath(appname)


@ -91,32 +88,20 @@ def set_init_dict(model_dict, checkpoint_state, c):
         if k not in model_dict:
             print(" | > Layer missing in the model definition: {}".format(k))
     # 1. filter out unnecessary keys
-    pretrained_dict = {
-        k: v
-        for k, v in checkpoint_state.items() if k in model_dict
-    }
+    pretrained_dict = {k: v for k, v in checkpoint_state.items() if k in model_dict}
     # 2. filter out different size layers
-    pretrained_dict = {
-        k: v
-        for k, v in pretrained_dict.items()
-        if v.numel() == model_dict[k].numel()
-    }
+    pretrained_dict = {k: v for k, v in pretrained_dict.items() if v.numel() == model_dict[k].numel()}
     # 3. skip reinit layers
     if c.reinit_layers is not None:
         for reinit_layer_name in c.reinit_layers:
-            pretrained_dict = {
-                k: v
-                for k, v in pretrained_dict.items()
-                if reinit_layer_name not in k
-            }
+            pretrained_dict = {k: v for k, v in pretrained_dict.items() if reinit_layer_name not in k}
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
-                                                     len(model_dict)))
+    print(" | > {} / {} layers are restored.".format(len(pretrained_dict), len(model_dict)))
     return model_dict
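The comprehension filters in set_init_dict implement a tolerant partial restore; a condensed sketch of the same idea with toy tensors:

import torch

model_dict = {"enc.w": torch.zeros(2, 2), "dec.w": torch.zeros(3)}
checkpoint = {"enc.w": torch.ones(2, 2), "dec.w": torch.ones(4), "old.b": torch.ones(1)}

pretrained = {k: v for k, v in checkpoint.items() if k in model_dict}  # drop unknown keys
pretrained = {k: v for k, v in pretrained.items()
              if v.numel() == model_dict[k].numel()}                   # drop shape mismatches
model_dict.update(pretrained)  # "enc.w" restored; "dec.w" keeps its fresh init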
-class KeepAverage():
+class KeepAverage:
     def __init__(self):
         self.avg_values = {}
         self.iters = {}

@ -141,8 +126,7 @@ class KeepAverage():
             self.avg_values[name] = 0.99 * self.avg_values[name] + 0.01 * value
             self.iters[name] += 1
         else:
-            self.avg_values[name] = self.avg_values[name] * \
-                self.iters[name] + value
+            self.avg_values[name] = self.avg_values[name] * self.iters[name] + value
             self.iters[name] += 1
             self.avg_values[name] /= self.iters[name]

@ -155,23 +139,27 @@ class KeepAverage():
             self.update_value(key, value)


-def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None):
+def check_argument(
+    name, c, enum_list=None, max_val=None, min_val=None, restricted=False, val_type=None, alternative=None
+):
     if alternative in c.keys() and c[alternative] is not None:
         return
     if restricted:
-        assert name in c.keys(), f' [!] {name} not defined in config.json'
+        assert name in c.keys(), f" [!] {name} not defined in config.json"
     if name in c.keys():
         if max_val:
-            assert c[name] <= max_val, f' [!] {name} is larger than max value {max_val}'
+            assert c[name] <= max_val, f" [!] {name} is larger than max value {max_val}"
         if min_val:
-            assert c[name] >= min_val, f' [!] {name} is smaller than min value {min_val}'
+            assert c[name] >= min_val, f" [!] {name} is smaller than min value {min_val}"
         if enum_list:
-            assert c[name].lower() in enum_list, f' [!] {name} is not a valid value'
+            assert c[name].lower() in enum_list, f" [!] {name} is not a valid value"
         if isinstance(val_type, list):
             is_valid = False
             for typ in val_type:
                 if isinstance(c[name], typ):
                     is_valid = True
-            assert is_valid or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+            assert is_valid or c[name] is None, f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
         elif val_type:
-            assert isinstance(c[name], val_type) or c[name] is None, f' [!] {name} has wrong type - {type(c[name])} vs {val_type}'
+            assert (
+                isinstance(c[name], val_type) or c[name] is None
+            ), f" [!] {name} has wrong type - {type(c[name])} vs {val_type}"
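Typical check_argument calls, given the signature above; the config keys and bounds here are illustrative, not the project's real schema:

c = {"num_mels": 80, "log_func": "np.log10"}
check_argument("num_mels", c, restricted=True, val_type=int, min_val=10, max_val=2056)
check_argument("log_func", c, restricted=True, val_type=str)
check_argument("seq_len", c, restricted=False, val_type=int)  # absent but optional: passes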
@ -8,15 +8,17 @@ from shutil import copyfile

 class RenamingUnpickler(pickle_tts.Unpickler):
     """Overload default pickler to solve module renaming problem"""

     def find_class(self, module, name):
-        return super().find_class(module.replace('mozilla_voice_tts', 'TTS'), name)
+        return super().find_class(module.replace("mozilla_voice_tts", "TTS"), name)


 class AttrDict(dict):
     """A custom dict which converts dict keys
     to class attributes"""

     def __init__(self, *args, **kwargs):
-        super(AttrDict, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self.__dict__ = self


@ -25,11 +27,12 @@ def read_json_with_comments(json_path):
     with open(json_path, "r", encoding="utf-8") as f:
         input_str = f.read()
     # handle comments
-    input_str = re.sub(r'\\\n', '', input_str)
-    input_str = re.sub(r'//.*\n', '\n', input_str)
+    input_str = re.sub(r"\\\n", "", input_str)
+    input_str = re.sub(r"//.*\n", "\n", input_str)
     data = json.loads(input_str)
     return data


 def load_config(config_path: str) -> AttrDict:
     """Load config files and discard comments

@ -60,7 +63,7 @@ def copy_model_files(c, config_file, out_path, new_fields):
     in the config file.
     """
     # copy config.json
-    copy_config_path = os.path.join(out_path, 'config.json')
+    copy_config_path = os.path.join(out_path, "config.json")
     config_lines = open(config_file, "r", encoding="utf-8").readlines()
     # add extra information fields
     for key, value in new_fields.items():

@ -73,7 +76,10 @@ def copy_model_files(c, config_file, out_path, new_fields):
     config_out_file.writelines(config_lines)
     config_out_file.close()
     # copy model stats file if available
-    if c.audio['stats_path'] is not None:
-        copy_stats_path = os.path.join(out_path, 'scale_stats.npy')
+    if c.audio["stats_path"] is not None:
+        copy_stats_path = os.path.join(out_path, "scale_stats.npy")
         if not os.path.exists(copy_stats_path):
-            copyfile(c.audio['stats_path'], copy_stats_path, )
+            copyfile(
+                c.audio["stats_path"],
+                copy_stats_path,
+            )
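The two re.sub calls above are what let config files carry // comments; a self-contained demonstration of that stripping on an in-memory string:

import json
import re

raw = """{
    "model": "tacotron2",  // inline comment
    "audio": {"sample_rate": 22050}
}"""
cleaned = re.sub(r"\\\n", "", raw)          # join escaped line breaks
cleaned = re.sub(r"//.*\n", "\n", cleaned)  # drop // comments to end of line
print(json.loads(cleaned)["audio"]["sample_rate"])  # 22050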
@ -22,12 +22,13 @@ class ModelManager(object):
     Args:
         models_file (str): path to .model.json
     """

     def __init__(self, models_file=None, output_prefix=None):
         super().__init__()
         if output_prefix is None:
-            self.output_prefix = get_user_data_dir('tts')
+            self.output_prefix = get_user_data_dir("tts")
         else:
-            self.output_prefix = os.path.join(output_prefix, 'tts')
+            self.output_prefix = os.path.join(output_prefix, "tts")
         self.url_prefix = "https://drive.google.com/uc?id="
         self.models_dict = None
         if models_file is not None:

@ -72,7 +73,7 @@ class ModelManager(object):
                         print(f" >: {model_type}/{lang}/{dataset}/{model} [already downloaded]")
                     else:
                         print(f" >: {model_type}/{lang}/{dataset}/{model}")
-                    models_name_list.append(f'{model_type}/{lang}/{dataset}/{model}')
+                    models_name_list.append(f"{model_type}/{lang}/{dataset}/{model}")
         return models_name_list

     def download_model(self, model_name):

@ -104,25 +105,25 @@ class ModelManager(object):
         else:
             os.makedirs(output_path, exist_ok=True)
             print(f" > Downloading model to {output_path}")
-            output_stats_path = os.path.join(output_path, 'scale_stats.npy')
+            output_stats_path = os.path.join(output_path, "scale_stats.npy")
             # download files to the output path
-            if self._check_dict_key(model_item, 'github_rls_url'):
+            if self._check_dict_key(model_item, "github_rls_url"):
                 # download from github release
                 # TODO: pass output_path
-                self._download_zip_file(model_item['github_rls_url'], output_path)
+                self._download_zip_file(model_item["github_rls_url"], output_path)
             else:
                 # download from gdrive
-                self._download_gdrive_file(model_item['model_file'], output_model_path)
-                self._download_gdrive_file(model_item['config_file'], output_config_path)
-                if self._check_dict_key(model_item, 'stats_file'):
-                    self._download_gdrive_file(model_item['stats_file'], output_stats_path)
+                self._download_gdrive_file(model_item["model_file"], output_model_path)
+                self._download_gdrive_file(model_item["config_file"], output_config_path)
+                if self._check_dict_key(model_item, "stats_file"):
+                    self._download_gdrive_file(model_item["stats_file"], output_stats_path)

             # set the scale_path.npy file path in the model config.json
-            if self._check_dict_key(model_item, 'stats_file') or os.path.exists(output_stats_path):
+            if self._check_dict_key(model_item, "stats_file") or os.path.exists(output_stats_path):
                 # set scale stats path in config.json
                 config_path = output_config_path
                 config = load_config(config_path)
-                config["audio"]['stats_path'] = output_stats_path
+                config["audio"]["stats_path"] = output_stats_path
                 with open(config_path, "w") as jf:
                     json.dump(config, jf)
         return output_model_path, output_config_path, model_item
@ -6,7 +6,6 @@ from torch.optim.optimizer import Optimizer


 class RAdam(Optimizer):

     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
         if lr < 0.0:
             raise ValueError("Invalid learning rate: {}".format(lr))

@ -20,13 +19,15 @@ class RAdam(Optimizer):
         self.degenerated_to_sgd = degenerated_to_sgd
         if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
             for param in params:
-                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
-                    param['buffer'] = [[None, None, None] for _ in range(10)]
-        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
-        super(RAdam, self).__init__(params, defaults)
+                if "betas" in param and (param["betas"][0] != betas[0] or param["betas"][1] != betas[1]):
+                    param["buffer"] = [[None, None, None] for _ in range(10)]
+        defaults = dict(
+            lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]
+        )
+        super().__init__(params, defaults)

     def __setstate__(self, state):  # pylint: disable=useless-super-delegation
-        super(RAdam, self).__setstate__(state)
+        super().__setstate__(state)

     def step(self, closure=None):

@ -36,62 +37,70 @@ class RAdam(Optimizer):

         for group in self.param_groups:

-            for p in group['params']:
+            for p in group["params"]:
                 if p.grad is None:
                     continue
                 grad = p.grad.data.float()
                 if grad.is_sparse:
-                    raise RuntimeError('RAdam does not support sparse gradients')
+                    raise RuntimeError("RAdam does not support sparse gradients")

                 p_data_fp32 = p.data.float()

                 state = self.state[p]

                 if len(state) == 0:
-                    state['step'] = 0
-                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
-                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                    state["step"] = 0
+                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
+                    state["exp_avg_sq"] = torch.zeros_like(p_data_fp32)
                 else:
-                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
-                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+                    state["exp_avg"] = state["exp_avg"].type_as(p_data_fp32)
+                    state["exp_avg_sq"] = state["exp_avg_sq"].type_as(p_data_fp32)

-                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
-                beta1, beta2 = group['betas']
+                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+                beta1, beta2 = group["betas"]

                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

-                state['step'] += 1
-                buffered = group['buffer'][int(state['step'] % 10)]
-                if state['step'] == buffered[0]:
+                state["step"] += 1
+                buffered = group["buffer"][int(state["step"] % 10)]
+                if state["step"] == buffered[0]:
                     N_sma, step_size = buffered[1], buffered[2]
                 else:
-                    buffered[0] = state['step']
-                    beta2_t = beta2 ** state['step']
+                    buffered[0] = state["step"]
+                    beta2_t = beta2 ** state["step"]
                     N_sma_max = 2 / (1 - beta2) - 1
-                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    N_sma = N_sma_max - 2 * state["step"] * beta2_t / (1 - beta2_t)
                     buffered[1] = N_sma

                     # more conservative since it's an approximated value
                     if N_sma >= 5:
-                        step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                        step_size = math.sqrt(
+                            (1 - beta2_t)
+                            * (N_sma - 4)
+                            / (N_sma_max - 4)
+                            * (N_sma - 2)
+                            / N_sma
+                            * N_sma_max
+                            / (N_sma_max - 2)
+                        ) / (1 - beta1 ** state["step"])
                     elif self.degenerated_to_sgd:
-                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                        step_size = 1.0 / (1 - beta1 ** state["step"])
                     else:
                         step_size = -1
                     buffered[2] = step_size

                 # more conservative since it's an approximated value
                 if N_sma >= 5:
-                    if group['weight_decay'] != 0:
-                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
-                    denom = exp_avg_sq.sqrt().add_(group['eps'])
-                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
+                    if group["weight_decay"] != 0:
+                        p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
+                    denom = exp_avg_sq.sqrt().add_(group["eps"])
+                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group["lr"])
                     p.data.copy_(p_data_fp32)
                 elif step_size > 0:
-                    if group['weight_decay'] != 0:
-                        p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
-                    p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
+                    if group["weight_decay"] != 0:
+                        p_data_fp32.add_(p_data_fp32, alpha=-group["weight_decay"] * group["lr"])
+                    p_data_fp32.add_(exp_avg, alpha=-step_size * group["lr"])
                     p.data.copy_(p_data_fp32)

         return loss
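The variance rectification in RAdam's step can be traced by hand; evaluating the formulas above for beta2 = 0.999 at step t = 100:

import math

beta1, beta2, t = 0.9, 0.999, 100
beta2_t = beta2 ** t                                   # ~0.9048
N_sma_max = 2 / (1 - beta2) - 1                        # 1999
N_sma = N_sma_max - 2 * t * beta2_t / (1 - beta2_t)    # ~98.4, so N_sma >= 5
step_size = math.sqrt(
    (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)
) / (1 - beta1 ** t)
print(round(step_size, 4))  # ~0.0665: the rectified Adam branch is taken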
@ -9,6 +9,7 @@ from TTS.utils.io import load_config
 from TTS.tts.utils.generic_utils import setup_model
 from TTS.tts.utils.speakers import load_speaker_mapping
 from TTS.vocoder.utils.generic_utils import setup_generator, interpolate_vocoder_input

 # pylint: disable=unused-wildcard-import
 # pylint: disable=wildcard-import
 from TTS.tts.utils.synthesis import synthesis, trim_silence

@ -49,12 +50,11 @@ class Synthesizer(object):
         self.use_cuda = use_cuda
         if self.use_cuda:
             assert torch.cuda.is_available(), "CUDA is not availabe on this machine."
-        self.load_tts(tts_checkpoint, tts_config,
-                      use_cuda)
-        self.output_sample_rate = self.tts_config.audio['sample_rate']
+        self.load_tts(tts_checkpoint, tts_config, use_cuda)
+        self.output_sample_rate = self.tts_config.audio["sample_rate"]
         if vocoder_checkpoint:
             self.load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
-            self.output_sample_rate = self.vocoder_config.audio['sample_rate']
+            self.output_sample_rate = self.vocoder_config.audio["sample_rate"]

     @staticmethod
     def get_segmenter(lang):

@ -69,16 +69,18 @@ class Synthesizer(object):
             self.num_speakers = 0
             # set external speaker embedding
             if self.tts_config.use_external_speaker_embedding_file:
-                speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]['embedding']
+                speaker_embedding = self.tts_speakers[list(self.tts_speakers.keys())[0]]["embedding"]
                 self.speaker_embedding_dim = len(speaker_embedding)

     def init_speaker(self, speaker_idx):
         # load speakers
         speaker_embedding = None
-        if hasattr(self, 'tts_speakers') and speaker_idx is not None:
-            assert speaker_idx < len(self.tts_speakers), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
+        if hasattr(self, "tts_speakers") and speaker_idx is not None:
+            assert speaker_idx < len(
+                self.tts_speakers
+            ), f" [!] speaker_idx is out of the range. {speaker_idx} vs {len(self.tts_speakers)}"
             if self.tts_config.use_external_speaker_embedding_file:
-                speaker_embedding = self.tts_speakers[speaker_idx]['embedding']
+                speaker_embedding = self.tts_speakers[speaker_idx]["embedding"]
         return speaker_embedding

     def load_tts(self, tts_checkpoint, tts_config, use_cuda):

@ -90,7 +92,7 @@ class Synthesizer(object):
         self.use_phonemes = self.tts_config.use_phonemes
         self.ap = AudioProcessor(verbose=False, **self.tts_config.audio)

-        if 'characters' in self.tts_config.keys():
+        if "characters" in self.tts_config.keys():
             symbols, phonemes = make_symbols(**self.tts_config.characters)

         if self.use_phonemes:

@ -105,7 +107,7 @@ class Synthesizer(object):

     def load_vocoder(self, model_file, model_config, use_cuda):
         self.vocoder_config = load_config(model_config)
-        self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config['audio'])
+        self.vocoder_ap = AudioProcessor(verbose=False, **self.vocoder_config["audio"])
         self.vocoder_model = setup_generator(self.vocoder_config)
         self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
         if use_cuda:

@ -141,7 +143,8 @@ class Synthesizer(object):
                 False,
                 self.tts_config.enable_eos_bos_chars,
                 use_gl,
-                speaker_embedding=speaker_embedding)
+                speaker_embedding=speaker_embedding,
+            )
             if not use_gl:
                 # denormalize tts output based on tts audio config
                 mel_postnet_spec = self.ap.denormalize(mel_postnet_spec.T).T

@ -149,7 +152,7 @@ class Synthesizer(object):
                 # renormalize spectrogram based on vocoder config
                 vocoder_input = self.vocoder_ap.normalize(mel_postnet_spec.T)
                 # compute scale factor for possible sample rate mismatch
-                scale_factor = [1, self.vocoder_config['audio']['sample_rate'] / self.ap.sample_rate]
+                scale_factor = [1, self.vocoder_config["audio"]["sample_rate"] / self.ap.sample_rate]
                 if scale_factor[1] != 1:
                     print(" > interpolating tts model output.")
                     vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input)

@ -172,7 +175,7 @@ class Synthesizer(object):

         # compute stats
         process_time = time.time() - start_time
-        audio_time = len(wavs) / self.tts_config.audio['sample_rate']
+        audio_time = len(wavs) / self.tts_config.audio["sample_rate"]
         print(f" > Processing time: {process_time}")
         print(f" > Real-time factor: {process_time / audio_time}")
         return wavs
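interpolate_vocoder_input is imported from the vocoder utils and its body is not shown in this diff; a plausible minimal equivalent of the time-axis rescaling it performs, inferred from the scale_factor computed above (this is an assumption, not the library's implementation):

import torch

def interpolate_mel(scale_factor, spec):
    # spec: (n_mels, T) numpy array -> mel stretched along the time axis
    spec = torch.tensor(spec).unsqueeze(0)  # (1, n_mels, T)
    return torch.nn.functional.interpolate(
        spec, scale_factor=scale_factor[1], mode="linear", align_corners=True
    )

scale = [1, 24000 / 22050]  # vocoder sr / tts sr; illustrative rates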
@ -13,40 +13,28 @@ class TensorboardLogger(object):
         layer_num = 1
         for name, param in model.named_parameters():
             if param.numel() == 1:
-                self.writer.add_scalar(
-                    "layer{}-{}/value".format(layer_num, name),
-                    param.max(), step)
+                self.writer.add_scalar("layer{}-{}/value".format(layer_num, name), param.max(), step)
             else:
-                self.writer.add_scalar(
-                    "layer{}-{}/max".format(layer_num, name),
-                    param.max(), step)
-                self.writer.add_scalar(
-                    "layer{}-{}/min".format(layer_num, name),
-                    param.min(), step)
-                self.writer.add_scalar(
-                    "layer{}-{}/mean".format(layer_num, name),
-                    param.mean(), step)
-                self.writer.add_scalar(
-                    "layer{}-{}/std".format(layer_num, name),
-                    param.std(), step)
-                self.writer.add_histogram(
-                    "layer{}-{}/param".format(layer_num, name), param, step)
-                self.writer.add_histogram(
-                    "layer{}-{}/grad".format(layer_num, name), param.grad, step)
+                self.writer.add_scalar("layer{}-{}/max".format(layer_num, name), param.max(), step)
+                self.writer.add_scalar("layer{}-{}/min".format(layer_num, name), param.min(), step)
+                self.writer.add_scalar("layer{}-{}/mean".format(layer_num, name), param.mean(), step)
+                self.writer.add_scalar("layer{}-{}/std".format(layer_num, name), param.std(), step)
+                self.writer.add_histogram("layer{}-{}/param".format(layer_num, name), param, step)
+                self.writer.add_histogram("layer{}-{}/grad".format(layer_num, name), param.grad, step)
             layer_num += 1

     def dict_to_tb_scalar(self, scope_name, stats, step):
         for key, value in stats.items():
-            self.writer.add_scalar('{}/{}'.format(scope_name, key), value, step)
+            self.writer.add_scalar("{}/{}".format(scope_name, key), value, step)

     def dict_to_tb_figure(self, scope_name, figures, step):
         for key, value in figures.items():
-            self.writer.add_figure('{}/{}'.format(scope_name, key), value, step)
+            self.writer.add_figure("{}/{}".format(scope_name, key), value, step)

     def dict_to_tb_audios(self, scope_name, audios, step, sample_rate):
         for key, value in audios.items():
             try:
-                self.writer.add_audio('{}/{}'.format(scope_name, key), value, step, sample_rate=sample_rate)
+                self.writer.add_audio("{}/{}".format(scope_name, key), value, step, sample_rate=sample_rate)
             except RuntimeError:
                 traceback.print_exc()
@ -14,12 +14,13 @@ def setup_torch_training_env(cudnn_enable, cudnn_benchmark):


 def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):
-    r'''Check model gradient against unexpected jumps and failures'''
+    r"""Check model gradient against unexpected jumps and failures"""
     skip_flag = False
     if ignore_stopnet:
         if not amp_opt_params:
             grad_norm = torch.nn.utils.clip_grad_norm_(
-                [param for name, param in model.named_parameters() if 'stopnet' not in name], grad_clip)
+                [param for name, param in model.named_parameters() if "stopnet" not in name], grad_clip
+            )
         else:
             grad_norm = torch.nn.utils.clip_grad_norm_(amp_opt_params, grad_clip)
     else:

@ -41,11 +42,10 @@ def check_update(model, grad_clip, ignore_stopnet=False, amp_opt_params=None):


 def lr_decay(init_lr, global_step, warmup_steps):
-    r'''from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py'''
+    r"""from https://github.com/r9y9/tacotron_pytorch/blob/master/train.py"""
     warmup_steps = float(warmup_steps)
-    step = global_step + 1.
-    lr = init_lr * warmup_steps**0.5 * np.minimum(step * warmup_steps**-1.5,
-                                                  step**-0.5)
+    step = global_step + 1.0
+    lr = init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)
     return lr
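lr_decay is the Noam schedule: a linear warm-up to init_lr followed by inverse-square-root decay. Evaluating it at a few steps (warmup_steps=4000 is an illustrative choice):

import numpy as np

def lr_decay(init_lr, global_step, warmup_steps):
    warmup_steps = float(warmup_steps)
    step = global_step + 1.0
    return init_lr * warmup_steps ** 0.5 * np.minimum(step * warmup_steps ** -1.5, step ** -0.5)

for s in (0, 999, 3999, 15999):
    print(s, float(lr_decay(1e-3, s, warmup_steps=4000)))
# ~2.5e-07, ~2.5e-04, peaks at 1e-03 when step == warmup_steps, then ~5e-04 at 16k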
@ -54,14 +54,14 @@ def adam_weight_decay(optimizer):
     Custom weight decay operation, not effecting grad values.
     """
     for group in optimizer.param_groups:
-        for param in group['params']:
-            current_lr = group['lr']
-            weight_decay = group['weight_decay']
-            factor = -weight_decay * group['lr']
-            param.data = param.data.add(param.data,
-                                        alpha=factor)
+        for param in group["params"]:
+            current_lr = group["lr"]
+            weight_decay = group["weight_decay"]
+            factor = -weight_decay * group["lr"]
+            param.data = param.data.add(param.data, alpha=factor)
     return optimizer, current_lr


 # pylint: disable=dangerous-default-value
 def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn", "lstm", "gru", "embedding"}):
     """

@ -74,30 +74,23 @@ def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v", "rnn
         if not param.requires_grad:
             continue

-        if len(param.shape) == 1 or any([skip_name in name for skip_name in skip_list]):
+        if len(param.shape) == 1 or any((skip_name in name for skip_name in skip_list)):
             no_decay.append(param)
         else:
             decay.append(param)
-    return [{
-        'params': no_decay,
-        'weight_decay': 0.
-    }, {
-        'params': decay,
-        'weight_decay': weight_decay
-    }]
+    return [{"params": no_decay, "weight_decay": 0.0}, {"params": decay, "weight_decay": weight_decay}]


 # pylint: disable=protected-access
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
-        super(NoamLR, self).__init__(optimizer, last_epoch)
+        super().__init__(optimizer, last_epoch)

     def get_lr(self):
         step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps**0.5 *
-            min(step * self.warmup_steps**-1.5, step**-0.5)
+            base_lr * self.warmup_steps ** 0.5 * min(step * self.warmup_steps ** -1.5, step ** -0.5)
             for base_lr in self.base_lrs
         ]
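The param groups returned by set_weight_decay feed straight into an optimizer constructor; a toy model makes the split visible (the model architecture here is invented for illustration):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
groups = set_weight_decay(model, weight_decay=1e-6)
optimizer = torch.optim.Adam(groups, lr=1e-3)
# biases and other 1-D parameters (e.g. LayerNorm weights) land in the zero-decay group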
@ -13,19 +13,22 @@ class GANDataset(Dataset):
     and converts them to acoustic features on the fly and returns
     random segments of (audio, feature) couples.
     """
-    def __init__(self,
-                 ap,
-                 items,
-                 seq_len,
-                 hop_len,
-                 pad_short,
-                 conv_pad=2,
-                 is_training=True,
-                 return_segments=True,
-                 use_noise_augment=False,
-                 use_cache=False,
-                 verbose=False):
-        super(GANDataset, self).__init__()
+
+    def __init__(
+        self,
+        ap,
+        items,
+        seq_len,
+        hop_len,
+        pad_short,
+        conv_pad=2,
+        is_training=True,
+        return_segments=True,
+        use_noise_augment=False,
+        use_cache=False,
+        verbose=False,
+    ):
+        super().__init__()
         self.ap = ap
         self.item_list = items
         self.compute_feat = not isinstance(items[0], (tuple, list))

@ -57,14 +60,14 @@ class GANDataset(Dataset):

     @staticmethod
     def find_wav_files(path):
-        return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
+        return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)

     def __len__(self):
         return len(self.item_list)

     def __getitem__(self, idx):
-        """ Return different items for Generator and Discriminator and
-        cache acoustic features """
+        """Return different items for Generator and Discriminator and
+        cache acoustic features"""
         if self.return_segments:
             idx2 = self.G_to_D_mappings[idx]
             item1 = self.load_item(idx)

@ -76,13 +79,16 @@ class GANDataset(Dataset):
     def _pad_short_samples(self, audio, mel=None):
         """Pad samples shorter than the output sequence length"""
         if len(audio) < self.seq_len:
-            audio = np.pad(audio, (0, self.seq_len - len(audio)),
-                           mode='constant',
-                           constant_values=0.0)
+            audio = np.pad(audio, (0, self.seq_len - len(audio)), mode="constant", constant_values=0.0)

         if mel is not None and mel.shape[1] < self.feat_frame_len:
             pad_value = self.ap.melspectrogram(np.zeros([self.ap.win_length]))[:, 0]
-            mel = np.pad(mel, ([0, 0], [0, self.feat_frame_len - mel.shape[1]]), mode='constant', constant_values=pad_value.mean())
+            mel = np.pad(
+                mel,
+                ([0, 0], [0, self.feat_frame_len - mel.shape[1]]),
+                mode="constant",
+                constant_values=pad_value.mean(),
+            )
         return audio, mel

     def shuffle_mapping(self):

@ -111,12 +117,14 @@ class GANDataset(Dataset):
             else:
                 audio = self.ap.load_wav(wavpath)
                 mel = np.load(feat_path)
-                audio, mel= self._pad_short_samples(audio, mel)
+                audio, mel = self._pad_short_samples(audio, mel)

         # correct the audio length wrt padding applied in stft
         audio = np.pad(audio, (0, self.hop_len), mode="edge")
-        audio = audio[:mel.shape[-1] * self.hop_len]
-        assert mel.shape[-1] * self.hop_len == audio.shape[-1], f' [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}'
+        audio = audio[: mel.shape[-1] * self.hop_len]
+        assert (
+            mel.shape[-1] * self.hop_len == audio.shape[-1]
+        ), f" [!] {mel.shape[-1] * self.hop_len} vs {audio.shape[-1]}"

         audio = torch.from_numpy(audio).float().unsqueeze(0)
         mel = torch.from_numpy(mel).float().squeeze(0)

@ -128,8 +136,7 @@ class GANDataset(Dataset):
             mel = mel[:, mel_start:mel_end]

             audio_start = mel_start * self.hop_len
-            audio = audio[:, audio_start:audio_start +
-                          self.seq_len]
+            audio = audio[:, audio_start : audio_start + self.seq_len]

         if self.use_noise_augment and self.is_training and self.return_segments:
             audio = audio + (1 / 32768) * torch.randn_like(audio)
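The noise augmentation in the last line adds roughly one least-significant bit of 16-bit audio as Gaussian dither; equivalently, as a standalone snippet:

import torch

audio = torch.randn(1, 8192)  # a placeholder waveform segment
audio = audio + (1 / 32768) * torch.randn_like(audio)  # ~one 16-bit quantization step of noise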
@ -18,11 +18,7 @@ def preprocess_wav_files(out_path, config, ap):
         mel = ap.melspectrogram(y)
         np.save(mel_path, mel)
         if isinstance(config.mode, int):
-            quant = (
-                ap.mulaw_encode(y, qc=config.mode)
-                if config.mulaw
-                else ap.quantize(y, bits=config.mode)
-            )
+            quant = ap.mulaw_encode(y, qc=config.mode) if config.mulaw else ap.quantize(y, bits=config.mode)
             np.save(quant_path, quant)
@ -13,18 +13,21 @@ class WaveGradDataset(Dataset):
     and converts them to acoustic features on the fly and returns
     random segments of (audio, feature) couples.
     """
-    def __init__(self,
-                 ap,
-                 items,
-                 seq_len,
-                 hop_len,
-                 pad_short,
-                 conv_pad=2,
-                 is_training=True,
-                 return_segments=True,
-                 use_noise_augment=False,
-                 use_cache=False,
-                 verbose=False):
+
+    def __init__(
+        self,
+        ap,
+        items,
+        seq_len,
+        hop_len,
+        pad_short,
+        conv_pad=2,
+        is_training=True,
+        return_segments=True,
+        use_noise_augment=False,
+        use_cache=False,
+        verbose=False,
+    ):

         super().__init__()
         self.ap = ap

@ -54,7 +57,7 @@ class WaveGradDataset(Dataset):

     @staticmethod
     def find_wav_files(path):
-        return glob.glob(os.path.join(path, '**', '*.wav'), recursive=True)
+        return glob.glob(os.path.join(path, "**", "*.wav"), recursive=True)

     def __len__(self):
         return len(self.item_list)

@ -86,13 +89,16 @@ class WaveGradDataset(Dataset):
             if self.return_segments:
                 # correct audio length wrt segment length
                 if audio.shape[-1] < self.seq_len + self.pad_short:
-                    audio = np.pad(audio, (0, self.seq_len + self.pad_short - len(audio)), \
-                            mode='constant', constant_values=0.0)
-                assert audio.shape[-1] >= self.seq_len + self.pad_short, f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"
+                    audio = np.pad(
+                        audio, (0, self.seq_len + self.pad_short - len(audio)), mode="constant", constant_values=0.0
+                    )
+                assert (
+                    audio.shape[-1] >= self.seq_len + self.pad_short
+                ), f"{audio.shape[-1]} vs {self.seq_len + self.pad_short}"

             # correct the audio length wrt hop length
             p = (audio.shape[-1] // self.hop_len + 1) * self.hop_len - audio.shape[-1]
-            audio = np.pad(audio, (0, p), mode='constant', constant_values=0.0)
+            audio = np.pad(audio, (0, p), mode="constant", constant_values=0.0)

             if self.use_cache:
                 self.cache[idx] = audio

@ -126,7 +132,7 @@ class WaveGradDataset(Dataset):
         for idx, b in enumerate(batch):
             mel = b[0]
             audio = b[1]
-            mels[idx, :, :mel.shape[1]] = mel
-            audios[idx, :audio.shape[0]] = audio
+            mels[idx, :, : mel.shape[1]] = mel
+            audios[idx, : audio.shape[0]] = audio

         return mels, audios
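The collate loop above right-pads each item into preallocated zero tensors so variable-length items batch together; sketched standalone with illustrative shapes:

import torch

batch = [(torch.randn(80, 50), torch.randn(12800)), (torch.randn(80, 42), torch.randn(10752))]
max_frames = max(b[0].shape[1] for b in batch)
max_samples = max(b[1].shape[0] for b in batch)
mels = torch.zeros(len(batch), 80, max_frames)
audios = torch.zeros(len(batch), max_samples)
for idx, (mel, audio) in enumerate(batch):
    mels[idx, :, : mel.shape[1]] = mel    # zero-pad shorter mels on the right
    audios[idx, : audio.shape[0]] = audio  # and the matching waveforms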
@ -9,19 +9,20 @@ class WaveRNNDataset(Dataset):
     and converts them to acoustic features on the fly.
     """

-    def __init__(self,
-                 ap,
-                 items,
-                 seq_len,
-                 hop_len,
-                 pad,
-                 mode,
-                 mulaw,
-                 is_training=True,
-                 verbose=False,
-                 ):
+    def __init__(
+        self,
+        ap,
+        items,
+        seq_len,
+        hop_len,
+        pad,
+        mode,
+        mulaw,
+        is_training=True,
+        verbose=False,
+    ):

-        super(WaveRNNDataset, self).__init__()
+        super().__init__()
         self.ap = ap
         self.compute_feat = not isinstance(items[0], (tuple, list))
         self.item_list = items

@ -61,8 +62,9 @@ class WaveRNNDataset(Dataset):
         if self.mode in ["gauss", "mold"]:
             x_input = audio
         elif isinstance(self.mode, int):
-            x_input = (self.ap.mulaw_encode(audio, qc=self.mode)
-                       if self.mulaw else self.ap.quantize(audio, bits=self.mode))
+            x_input = (
+                self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode)
+            )
         else:
             raise RuntimeError("Unknown dataset mode - ", self.mode)

@ -71,7 +73,7 @@ class WaveRNNDataset(Dataset):
         wavpath, feat_path = self.item_list[index]
         mel = np.load(feat_path.replace("/quant/", "/mel/"))

-        if mel.shape[-1] < self.mel_len + 2 * self.pad:
+        if mel.shape[-1] < self.mel_len + 2 * self.pad:
             print(" [!] Instance is too short! : {}".format(wavpath))
             self.item_list[index] = self.item_list[index + 1]
             feat_path = self.item_list[index]

@ -87,22 +89,14 @@ class WaveRNNDataset(Dataset):

     def collate(self, batch):
         mel_win = self.seq_len // self.hop_len + 2 * self.pad
-        max_offsets = [x[0].shape[-1] -
-                       (mel_win + 2 * self.pad) for x in batch]
+        max_offsets = [x[0].shape[-1] - (mel_win + 2 * self.pad) for x in batch]

         mel_offsets = [np.random.randint(0, offset) for offset in max_offsets]
-        sig_offsets = [(offset + self.pad) *
-                       self.hop_len for offset in mel_offsets]
+        sig_offsets = [(offset + self.pad) * self.hop_len for offset in mel_offsets]

-        mels = [
-            x[0][:, mel_offsets[i]: mel_offsets[i] + mel_win]
-            for i, x in enumerate(batch)
-        ]
+        mels = [x[0][:, mel_offsets[i] : mel_offsets[i] + mel_win] for i, x in enumerate(batch)]

-        coarse = [
-            x[1][sig_offsets[i]: sig_offsets[i] + self.seq_len + 1]
-            for i, x in enumerate(batch)
-        ]
+        coarse = [x[1][sig_offsets[i] : sig_offsets[i] + self.seq_len + 1] for i, x in enumerate(batch)]

         mels = np.stack(mels).astype(np.float32)
         if self.mode in ["gauss", "mold"]:

@ -112,8 +106,7 @@ class WaveRNNDataset(Dataset):
         elif isinstance(self.mode, int):
             coarse = np.stack(coarse).astype(np.int64)
             coarse = torch.LongTensor(coarse)
-            x_input = (2 * coarse[:, : self.seq_len].float() /
-                       (2 ** self.mode - 1.0) - 1.0)
+            x_input = 2 * coarse[:, : self.seq_len].float() / (2 ** self.mode - 1.0) - 1.0
             y_coarse = coarse[:, 1:]
         mels = torch.FloatTensor(mels)
         return x_input, mels, y_coarse
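The offset arithmetic in collate keeps each sampled mel window aligned with its waveform slice; traced with illustrative sizes (hop_len=256, pad=2, seq_len=1280):

hop_len, pad, seq_len = 256, 2, 1280
mel_win = seq_len // hop_len + 2 * pad     # 9 mel frames per training window
mel_offset = 3                             # sampled per item in collate
sig_offset = (mel_offset + pad) * hop_len  # 1280: start of the audio slice
# mel slice:   [:, 3:12]          -> 9 frames (context frames included via pad)
# audio slice: [1280:1280 + 1281] -> seq_len + 1 samples, the extra one for teacher forcing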
@ -6,18 +6,21 @@ from torch.nn import functional as F

 class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
     """TODO: Merge this with audio.py"""
-    def __init__(self,
-                 n_fft,
-                 hop_length,
-                 win_length,
-                 pad_wav=False,
-                 window='hann_window',
-                 sample_rate=None,
-                 mel_fmin=0,
-                 mel_fmax=None,
-                 n_mels=80,
-                 use_mel=False):
-        super(TorchSTFT, self).__init__()
+
+    def __init__(
+        self,
+        n_fft,
+        hop_length,
+        win_length,
+        pad_wav=False,
+        window="hann_window",
+        sample_rate=None,
+        mel_fmin=0,
+        mel_fmax=None,
+        n_mels=80,
+        use_mel=False,
+    ):
+        super().__init__()
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.win_length = win_length

@ -27,8 +30,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
         self.mel_fmax = mel_fmax
         self.n_mels = n_mels
         self.use_mel = use_mel
-        self.window = nn.Parameter(getattr(torch, window)(win_length),
-                                   requires_grad=False)
+        self.window = nn.Parameter(getattr(torch, window)(win_length), requires_grad=False)
         self.mel_basis = None
         if use_mel:
             self._build_mel_basis()

@ -49,7 +51,7 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
             x = x.unsqueeze(1)
         if self.pad_wav:
             padding = int((self.n_fft - self.hop_length) / 2)
-            x = torch.nn.functional.pad(x, (padding, padding), mode='reflect')
+            x = torch.nn.functional.pad(x, (padding, padding), mode="reflect")
         # B x D x T x 2
         o = torch.stft(
             x.squeeze(1),

@ -61,35 +63,34 @@ class TorchSTFT(nn.Module):  # pylint: disable=abstract-method
             pad_mode="reflect",  # compatible with audio.py
             normalized=False,
             onesided=True,
-            return_complex=False)
+            return_complex=False,
+        )
         M = o[:, :, :, 0]
         P = o[:, :, :, 1]
-        S = torch.sqrt(torch.clamp(M**2 + P**2, min=1e-8))
+        S = torch.sqrt(torch.clamp(M ** 2 + P ** 2, min=1e-8))
         if self.use_mel:
             S = torch.matmul(self.mel_basis.to(x), S)
         return S

     def _build_mel_basis(self):
-        mel_basis = librosa.filters.mel(self.sample_rate,
-                                        self.n_fft,
-                                        n_mels=self.n_mels,
-                                        fmin=self.mel_fmin,
-                                        fmax=self.mel_fmax)
+        mel_basis = librosa.filters.mel(
+            self.sample_rate, self.n_fft, n_mels=self.n_mels, fmin=self.mel_fmin, fmax=self.mel_fmax
+        )
         self.mel_basis = torch.from_numpy(mel_basis).float()


 #################################
 # GENERATOR LOSSES
 #################################


 class STFTLoss(nn.Module):
-    """ STFT loss. Input generate and real waveforms are converted
+    """STFT loss. Input generate and real waveforms are converted
     to spectrograms compared with L1 and Spectral convergence losses.
     It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""

     def __init__(self, n_fft, hop_length, win_length):
-        super(STFTLoss, self).__init__()
+        super().__init__()
         self.n_fft = n_fft
         self.hop_length = hop_length
         self.win_length = win_length

@ -104,15 +105,14 @@ class STFTLoss(nn.Module):
         loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
         return loss_mag, loss_sc
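The two terms STFTLoss returns are a log-magnitude L1 and a spectral convergence term; pulled out as a free function over magnitude spectrograms (the formulas match the visible lines above):

import torch
import torch.nn.functional as F

def stft_loss_terms(y_hat_M, y_M):
    # y_hat_M, y_M: magnitude spectrograms of generated / real audio
    loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
    loss_sc = torch.norm(y_M - y_hat_M, p="fro") / torch.norm(y_M, p="fro")
    return loss_mag, loss_sc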
class MultiScaleSTFTLoss(torch.nn.Module):
|
||||
""" Multi-scale STFT loss. Input generate and real waveforms are converted
|
||||
"""Multi-scale STFT loss. Input generate and real waveforms are converted
|
||||
to spectrograms compared with L1 and Spectral convergence losses.
|
||||
It is from ParallelWaveGAN paper https://arxiv.org/pdf/1910.11480.pdf"""
|
||||
def __init__(self,
|
||||
n_ffts=(1024, 2048, 512),
|
||||
hop_lengths=(120, 240, 50),
|
||||
win_lengths=(600, 1200, 240)):
|
||||
super(MultiScaleSTFTLoss, self).__init__()
|
||||
|
||||
def __init__(self, n_ffts=(1024, 2048, 512), hop_lengths=(120, 240, 50), win_lengths=(600, 1200, 240)):
|
||||
super().__init__()
|
||||
self.loss_funcs = torch.nn.ModuleList()
|
||||
for n_fft, hop_length, win_length in zip(n_ffts, hop_lengths, win_lengths):
|
||||
self.loss_funcs.append(STFTLoss(n_fft, hop_length, win_length))
|
||||
|
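For reference, the two terms STFTLoss.forward returns above, as defined in the ParallelWaveGAN paper cited in the docstring (N is the number of spectrogram elements, matching the mean reduction of F.l1_loss):

    L_{sc}  = \frac{\lVert\, |STFT(y)| - |STFT(\hat{y})|\, \rVert_F}{\lVert\, |STFT(y)|\, \rVert_F}
    L_{mag} = \frac{1}{N} \big\lVert\, \log|STFT(y)| - \log|STFT(\hat{y})|\, \big\rVert_1

MultiScaleSTFTLoss then averages both terms over the three (n_fft, hop_length, win_length) resolutions listed in its defaults.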
@@ -129,19 +129,25 @@ class MultiScaleSTFTLoss(torch.nn.Module):
         loss_mag /= N
         return loss_mag, loss_sc


 class L1SpecLoss(nn.Module):
     """ L1 Loss over Spectrograms as described in HiFiGAN paper https://arxiv.org/pdf/2010.05646.pdf"""
-    def __init__(self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True):
+
+    def __init__(
+        self, sample_rate, n_fft, hop_length, win_length, mel_fmin=None, mel_fmax=None, n_mels=None, use_mel=True
+    ):
         super().__init__()
         self.use_mel = use_mel
-        self.stft = TorchSTFT(n_fft,
-                              hop_length,
-                              win_length,
-                              sample_rate=sample_rate,
-                              mel_fmin=mel_fmin,
-                              mel_fmax=mel_fmax,
-                              n_mels=n_mels,
-                              use_mel=use_mel)
+        self.stft = TorchSTFT(
+            n_fft,
+            hop_length,
+            win_length,
+            sample_rate=sample_rate,
+            mel_fmin=mel_fmin,
+            mel_fmax=mel_fmax,
+            n_mels=n_mels,
+            use_mel=use_mel,
+        )

     def forward(self, y_hat, y):
         y_hat_M = self.stft(y_hat)
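A hypothetical instantiation of the reformatted L1SpecLoss, assuming the classes above are in scope; every value here is illustrative, not from the commit (the positional (sr, n_fft) call into librosa.filters.mel shown above matches the librosa releases this commit targets):

    import torch

    loss_fn = L1SpecLoss(
        sample_rate=22050, n_fft=1024, hop_length=256, win_length=1024,
        mel_fmin=0, mel_fmax=8000, n_mels=80, use_mel=True,
    )
    y = torch.randn(4, 8192)      # reference waveforms, B x T
    y_hat = torch.randn(4, 8192)  # generated waveforms
    print(loss_fn(y_hat, y))      # scalar L1 distance between log-mel magnitudes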
@@ -150,9 +156,11 @@ class L1SpecLoss(nn.Module):
         loss_mag = F.l1_loss(torch.log(y_M), torch.log(y_hat_M))
         return loss_mag


 class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):
-    """ Multiscale STFT loss for multi band model outputs.
+    """Multiscale STFT loss for multi band model outputs.
     From MultiBand-MelGAN paper https://arxiv.org/abs/2005.05106"""
+
     # pylint: disable=no-self-use
     def forward(self, y_hat, y):
         y_hat = y_hat.view(-1, 1, y_hat.shape[2])
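The view in the forward above folds the subband axis into the batch so that the inherited MultiScaleSTFTLoss.forward scores every band as an independent mono waveform; a shape-only illustration with made-up dimensions:

    import torch

    y_hat = torch.randn(8, 4, 2048)           # B x num_subbands x T', e.g. a PQMF analysis output
    flat = y_hat.view(-1, 1, y_hat.shape[2])  # (B * num_subbands) x 1 x T'
    print(flat.shape)                         # torch.Size([32, 1, 2048])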
@@ -162,6 +170,7 @@ class MultiScaleSubbandSTFTLoss(MultiScaleSTFTLoss):

 class MSEGLoss(nn.Module):
     """ Mean Squared Generator Loss """
+
     # pylint: disable=no-self-use
     def forward(self, score_real):
         loss_fake = F.mse_loss(score_real, score_real.new_ones(score_real.shape))
@@ -170,10 +179,11 @@ class MSEGLoss(nn.Module):

 class HingeGLoss(nn.Module):
     """ Hinge Discriminator Loss """
+
     # pylint: disable=no-self-use
     def forward(self, score_real):
         # TODO: this might be wrong
-        loss_fake = torch.mean(F.relu(1. - score_real))
+        loss_fake = torch.mean(F.relu(1.0 - score_real))
         return loss_fake


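Both generator objectives above push the discriminator's score on generated audio toward the "real" target. For comparison, the LSGAN form that MSEGLoss implements and the usual non-saturating hinge form:

    L_G^{LS}    = \mathbb{E}\big[(D(\hat{y}) - 1)^2\big]
    L_G^{hinge} = -\mathbb{E}\big[D(\hat{y})\big]

The in-code TODO looks warranted: as written, HingeGLoss computes \mathbb{E}[\max(0, 1 - D(\hat{y}))] rather than -\mathbb{E}[D(\hat{y})], and its docstring still reads "Hinge Discriminator Loss" although the class is the generator-side loss; this formatting commit leaves both untouched.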
@@ -184,8 +194,11 @@ class HingeGLoss(nn.Module):

 class MSEDLoss(nn.Module):
     """ Mean Squared Discriminator Loss """
-    def __init__(self,):
-        super(MSEDLoss, self).__init__()
+
+    def __init__(
+        self,
+    ):
+        super().__init__()
         self.loss_func = nn.MSELoss()

     # pylint: disable=no-self-use
@@ -198,17 +211,20 @@ class MSEDLoss(nn.Module):

 class HingeDLoss(nn.Module):
     """ Hinge Discriminator Loss """
+
     # pylint: disable=no-self-use
     def forward(self, score_fake, score_real):
-        loss_real = torch.mean(F.relu(1. - score_real))
-        loss_fake = torch.mean(F.relu(1. + score_fake))
+        loss_real = torch.mean(F.relu(1.0 - score_real))
+        loss_fake = torch.mean(F.relu(1.0 + score_fake))
         loss_d = loss_real + loss_fake
         return loss_d, loss_real, loss_fake


 class MelganFeatureLoss(nn.Module):
-    def __init__(self,):
-        super(MelganFeatureLoss, self).__init__()
+    def __init__(
+        self,
+    ):
+        super().__init__()
         self.loss_func = nn.L1Loss()

     # pylint: disable=no-self-use
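HingeDLoss above is the standard hinge discriminator objective, penalizing real scores below +1 and fake scores above -1:

    L_D = \mathbb{E}\big[\max(0,\, 1 - D(y))\big] + \mathbb{E}\big[\max(0,\, 1 + D(\hat{y}))\big]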
@@ -229,8 +245,8 @@ class MelganFeatureLoss(nn.Module):


 def _apply_G_adv_loss(scores_fake, loss_func):
-    """ Compute G adversarial loss function
-    and normalize values """
+    """Compute G adversarial loss function
+    and normalize values"""
    adv_loss = 0
     if isinstance(scores_fake, list):
         for score_fake in scores_fake:
@@ -279,24 +295,26 @@ class GeneratorLoss(nn.Module):
     Args:
         C (AttrDict): model configuration.
     """

     def __init__(self, C):
         super().__init__()
-        assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
-            " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+        assert not (
+            C.use_mse_gan_loss and C.use_hinge_gan_loss
+        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

-        self.use_stft_loss = C.use_stft_loss if 'use_stft_loss' in C else False
-        self.use_subband_stft_loss = C.use_subband_stft_loss if 'use_subband_stft_loss' in C else False
-        self.use_mse_gan_loss = C.use_mse_gan_loss if 'use_mse_gan_loss' in C else False
-        self.use_hinge_gan_loss = C.use_hinge_gan_loss if 'use_hinge_gan_loss' in C else False
-        self.use_feat_match_loss = C.use_feat_match_loss if 'use_feat_match_loss' in C else False
-        self.use_l1_spec_loss = C.use_l1_spec_loss if 'use_l1_spec_loss' in C else False
+        self.use_stft_loss = C.use_stft_loss if "use_stft_loss" in C else False
+        self.use_subband_stft_loss = C.use_subband_stft_loss if "use_subband_stft_loss" in C else False
+        self.use_mse_gan_loss = C.use_mse_gan_loss if "use_mse_gan_loss" in C else False
+        self.use_hinge_gan_loss = C.use_hinge_gan_loss if "use_hinge_gan_loss" in C else False
+        self.use_feat_match_loss = C.use_feat_match_loss if "use_feat_match_loss" in C else False
+        self.use_l1_spec_loss = C.use_l1_spec_loss if "use_l1_spec_loss" in C else False

-        self.stft_loss_weight = C.stft_loss_weight if 'stft_loss_weight' in C else 0.0
-        self.subband_stft_loss_weight = C.subband_stft_loss_weight if 'subband_stft_loss_weight' in C else 0.0
-        self.mse_gan_loss_weight = C.mse_G_loss_weight if 'mse_G_loss_weight' in C else 0.0
-        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if 'hinde_G_loss_weight' in C else 0.0
-        self.feat_match_loss_weight = C.feat_match_loss_weight if 'feat_match_loss_weight' in C else 0.0
-        self.l1_spec_loss_weight = C.l1_spec_loss_weight if 'l1_spec_loss_weight' in C else 0.0
+        self.stft_loss_weight = C.stft_loss_weight if "stft_loss_weight" in C else 0.0
+        self.subband_stft_loss_weight = C.subband_stft_loss_weight if "subband_stft_loss_weight" in C else 0.0
+        self.mse_gan_loss_weight = C.mse_G_loss_weight if "mse_G_loss_weight" in C else 0.0
+        self.hinge_gan_loss_weight = C.hinge_G_loss_weight if "hinde_G_loss_weight" in C else 0.0
+        self.feat_match_loss_weight = C.feat_match_loss_weight if "feat_match_loss_weight" in C else 0.0
+        self.l1_spec_loss_weight = C.l1_spec_loss_weight if "l1_spec_loss_weight" in C else 0.0

         if C.use_stft_loss:
             self.stft_loss = MultiScaleSTFTLoss(**C.stft_loss_params)
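The guarded lookups above read optional keys from the AttrDict config; with any dict-like object the same pattern collapses to .get, as in this stand-in sketch (the config literal is illustrative, not from the commit):

    C = {"use_stft_loss": True, "stft_loss_weight": 0.5}
    use_stft_loss = C["use_stft_loss"] if "use_stft_loss" in C else False
    stft_loss_weight = C.get("stft_loss_weight", 0.0)  # equivalent and shorter for plain dicts

One pre-existing wrinkle survives the reformat: the hinge weight line guards on the misspelled key "hinde_G_loss_weight" while reading the correctly spelled C.hinge_G_loss_weight, so the weight silently stays 0.0 for configs that only define the correct spelling.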
@@ -309,63 +327,67 @@ class GeneratorLoss(nn.Module):
         if C.use_feat_match_loss:
             self.feat_match_loss = MelganFeatureLoss()
         if C.use_l1_spec_loss:
-            assert C.audio['sample_rate'] == C.l1_spec_loss_params['sample_rate']
+            assert C.audio["sample_rate"] == C.l1_spec_loss_params["sample_rate"]
             self.l1_spec_loss = L1SpecLoss(**C.l1_spec_loss_params)

-    def forward(self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None):
+    def forward(
+        self, y_hat=None, y=None, scores_fake=None, feats_fake=None, feats_real=None, y_hat_sub=None, y_sub=None
+    ):
         gen_loss = 0
         adv_loss = 0
         return_dict = {}

         # STFT Loss
         if self.use_stft_loss:
-            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, :y.size(2)].squeeze(1), y.squeeze(1))
-            return_dict['G_stft_loss_mg'] = stft_loss_mg
-            return_dict['G_stft_loss_sc'] = stft_loss_sc
+            stft_loss_mg, stft_loss_sc = self.stft_loss(y_hat[:, :, : y.size(2)].squeeze(1), y.squeeze(1))
+            return_dict["G_stft_loss_mg"] = stft_loss_mg
+            return_dict["G_stft_loss_sc"] = stft_loss_sc
             gen_loss = gen_loss + self.stft_loss_weight * (stft_loss_mg + stft_loss_sc)

         # L1 Spec loss
         if self.use_l1_spec_loss:
             l1_spec_loss = self.l1_spec_loss(y_hat, y)
-            return_dict['G_l1_spec_loss'] = l1_spec_loss
+            return_dict["G_l1_spec_loss"] = l1_spec_loss
             gen_loss = gen_loss + self.l1_spec_loss_weight * l1_spec_loss

         # subband STFT Loss
         if self.use_subband_stft_loss:
             subband_stft_loss_mg, subband_stft_loss_sc = self.subband_stft_loss(y_hat_sub, y_sub)
-            return_dict['G_subband_stft_loss_mg'] = subband_stft_loss_mg
-            return_dict['G_subband_stft_loss_sc'] = subband_stft_loss_sc
+            return_dict["G_subband_stft_loss_mg"] = subband_stft_loss_mg
+            return_dict["G_subband_stft_loss_sc"] = subband_stft_loss_sc
             gen_loss = gen_loss + self.subband_stft_loss_weight * (subband_stft_loss_mg + subband_stft_loss_sc)

         # multiscale MSE adversarial loss
         if self.use_mse_gan_loss and scores_fake is not None:
             mse_fake_loss = _apply_G_adv_loss(scores_fake, self.mse_loss)
-            return_dict['G_mse_fake_loss'] = mse_fake_loss
+            return_dict["G_mse_fake_loss"] = mse_fake_loss
             adv_loss = adv_loss + self.mse_gan_loss_weight * mse_fake_loss

         # multiscale Hinge adversarial loss
         if self.use_hinge_gan_loss and not scores_fake is not None:
             hinge_fake_loss = _apply_G_adv_loss(scores_fake, self.hinge_loss)
-            return_dict['G_hinge_fake_loss'] = hinge_fake_loss
+            return_dict["G_hinge_fake_loss"] = hinge_fake_loss
             adv_loss = adv_loss + self.hinge_gan_loss_weight * hinge_fake_loss

         # Feature Matching Loss
         if self.use_feat_match_loss and not feats_fake is None:
             feat_match_loss = self.feat_match_loss(feats_fake, feats_real)
-            return_dict['G_feat_match_loss'] = feat_match_loss
+            return_dict["G_feat_match_loss"] = feat_match_loss
             adv_loss = adv_loss + self.feat_match_loss_weight * feat_match_loss
-        return_dict['G_loss'] = gen_loss + adv_loss
-        return_dict['G_gen_loss'] = gen_loss
-        return_dict['G_adv_loss'] = adv_loss
+        return_dict["G_loss"] = gen_loss + adv_loss
+        return_dict["G_gen_loss"] = gen_loss
+        return_dict["G_adv_loss"] = adv_loss
         return return_dict


 class DiscriminatorLoss(nn.Module):
     """Like ```GeneratorLoss```"""

     def __init__(self, C):
         super().__init__()
-        assert not(C.use_mse_gan_loss and C.use_hinge_gan_loss),\
-            " [!] Cannot use HingeGANLoss and MSEGANLoss together."
+        assert not (
+            C.use_mse_gan_loss and C.use_hinge_gan_loss
+        ), " [!] Cannot use HingeGANLoss and MSEGANLoss together."

         self.use_mse_gan_loss = C.use_mse_gan_loss
         self.use_hinge_gan_loss = C.use_hinge_gan_loss
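Two guards in the forward above read oddly after the reformat, and both predate this commit: not scores_fake is not None reduces to scores_fake is None, meaning the hinge branch runs only when no discriminator scores were passed in (the opposite of the MSE branch), and not feats_fake is None is a double-negative spelling of feats_fake is not None. The pylint-preferred spellings, shown for comparison only; they preserve the original behavior rather than change it:

    if self.use_hinge_gan_loss and scores_fake is None:      # literal meaning of the original guard
        ...
    if self.use_feat_match_loss and feats_fake is not None:  # idiomatic form of the original guard
        ...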
@@ -381,23 +403,21 @@ class DiscriminatorLoss(nn.Module):

         if self.use_mse_gan_loss:
             mse_D_loss, mse_D_real_loss, mse_D_fake_loss = _apply_D_loss(
-                scores_fake=scores_fake,
-                scores_real=scores_real,
-                loss_func=self.mse_loss)
-            return_dict['D_mse_gan_loss'] = mse_D_loss
-            return_dict['D_mse_gan_real_loss'] = mse_D_real_loss
-            return_dict['D_mse_gan_fake_loss'] = mse_D_fake_loss
+                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.mse_loss
+            )
+            return_dict["D_mse_gan_loss"] = mse_D_loss
+            return_dict["D_mse_gan_real_loss"] = mse_D_real_loss
+            return_dict["D_mse_gan_fake_loss"] = mse_D_fake_loss
             loss += mse_D_loss

         if self.use_hinge_gan_loss:
             hinge_D_loss, hinge_D_real_loss, hinge_D_fake_loss = _apply_D_loss(
-                scores_fake=scores_fake,
-                scores_real=scores_real,
-                loss_func=self.hinge_loss)
-            return_dict['D_hinge_gan_loss'] = hinge_D_loss
-            return_dict['D_hinge_gan_real_loss'] = hinge_D_real_loss
-            return_dict['D_hinge_gan_fake_loss'] = hinge_D_fake_loss
+                scores_fake=scores_fake, scores_real=scores_real, loss_func=self.hinge_loss
+            )
+            return_dict["D_hinge_gan_loss"] = hinge_D_loss
+            return_dict["D_hinge_gan_real_loss"] = hinge_D_real_loss
+            return_dict["D_hinge_gan_fake_loss"] = hinge_D_fake_loss
             loss += hinge_D_loss

-        return_dict['D_loss'] = loss
+        return_dict["D_loss"] = loss
         return return_dict

@@ -4,7 +4,7 @@ from torch.nn.utils import weight_norm

 class ResidualStack(nn.Module):
     def __init__(self, channels, num_res_blocks, kernel_size):
-        super(ResidualStack, self).__init__()
+        super().__init__()

         assert (kernel_size - 1) % 2 == 0, " [!] kernel_size has to be odd."
         base_padding = (kernel_size - 1) // 2
@@ -12,26 +12,23 @@ class ResidualStack(nn.Module):
         self.blocks = nn.ModuleList()
         for idx in range(num_res_blocks):
             layer_kernel_size = kernel_size
-            layer_dilation = layer_kernel_size**idx
+            layer_dilation = layer_kernel_size ** idx
             layer_padding = base_padding * layer_dilation
-            self.blocks += [nn.Sequential(
-                nn.LeakyReLU(0.2),
-                nn.ReflectionPad1d(layer_padding),
-                weight_norm(
-                    nn.Conv1d(channels,
-                              channels,
-                              kernel_size=kernel_size,
-                              dilation=layer_dilation,
-                              bias=True)),
-                nn.LeakyReLU(0.2),
-                weight_norm(
-                    nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
-            )]
+            self.blocks += [
+                nn.Sequential(
+                    nn.LeakyReLU(0.2),
+                    nn.ReflectionPad1d(layer_padding),
+                    weight_norm(
+                        nn.Conv1d(channels, channels, kernel_size=kernel_size, dilation=layer_dilation, bias=True)
+                    ),
+                    nn.LeakyReLU(0.2),
+                    weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)),
+                )
+            ]

-        self.shortcuts = nn.ModuleList([
-            weight_norm(nn.Conv1d(channels, channels, kernel_size=1,
-                                  bias=True)) for i in range(num_res_blocks)
-        ])
+        self.shortcuts = nn.ModuleList(
+            [weight_norm(nn.Conv1d(channels, channels, kernel_size=1, bias=True)) for i in range(num_res_blocks)]
+        )

     def forward(self, x):
         for block, shortcut in zip(self.blocks, self.shortcuts):

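Padding arithmetic from the ResidualStack hunk above: the dilation grows as kernel_size ** idx and layer_padding = base_padding * dilation, which keeps the time dimension unchanged through every block. A quick standalone check with illustrative values (kernel size 3, three blocks):

    kernel_size = 3
    base_padding = (kernel_size - 1) // 2
    for idx in range(3):
        dilation = kernel_size ** idx
        padding = base_padding * dilation
        print(idx, dilation, padding)  # (0, 1, 1)  (1, 3, 3)  (2, 9, 9)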
@@ -4,54 +4,44 @@ from torch.nn import functional as F

 class ResidualBlock(torch.nn.Module):
     """Residual block module in WaveNet."""
-    def __init__(self,
-                 kernel_size=3,
-                 res_channels=64,
-                 gate_channels=128,
-                 skip_channels=64,
-                 aux_channels=80,
-                 dropout=0.0,
-                 dilation=1,
-                 bias=True,
-                 use_causal_conv=False):
-        super(ResidualBlock, self).__init__()
+
+    def __init__(
+        self,
+        kernel_size=3,
+        res_channels=64,
+        gate_channels=128,
+        skip_channels=64,
+        aux_channels=80,
+        dropout=0.0,
+        dilation=1,
+        bias=True,
+        use_causal_conv=False,
+    ):
+        super().__init__()
         self.dropout = dropout
         # no future time stamps available
         if use_causal_conv:
             padding = (kernel_size - 1) * dilation
         else:
-            assert (kernel_size -
-                    1) % 2 == 0, "Not support even number kernel size."
+            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
             padding = (kernel_size - 1) // 2 * dilation
         self.use_causal_conv = use_causal_conv

         # dilation conv
-        self.conv = torch.nn.Conv1d(res_channels,
-                                    gate_channels,
-                                    kernel_size,
-                                    padding=padding,
-                                    dilation=dilation,
-                                    bias=bias)
+        self.conv = torch.nn.Conv1d(
+            res_channels, gate_channels, kernel_size, padding=padding, dilation=dilation, bias=bias
+        )

         # local conditioning
         if aux_channels > 0:
-            self.conv1x1_aux = torch.nn.Conv1d(aux_channels,
-                                               gate_channels,
-                                               1,
-                                               bias=False)
+            self.conv1x1_aux = torch.nn.Conv1d(aux_channels, gate_channels, 1, bias=False)
         else:
             self.conv1x1_aux = None

         # conv output is split into two groups
         gate_out_channels = gate_channels // 2
-        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels,
-                                           res_channels,
-                                           1,
-                                           bias=bias)
-        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels,
-                                            skip_channels,
-                                            1,
-                                            bias=bias)
+        self.conv1x1_out = torch.nn.Conv1d(gate_out_channels, res_channels, 1, bias=bias)
+        self.conv1x1_skip = torch.nn.Conv1d(gate_out_channels, skip_channels, 1, bias=bias)

     def forward(self, x, c):
         """
@@ -63,7 +53,7 @@ class ResidualBlock(torch.nn.Module):
         x = self.conv(x)

         # remove future time steps if use_causal_conv conv
-        x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x
+        x = x[:, :, : residual.size(-1)] if self.use_causal_conv else x

         # split into two part for gated activation
         splitdim = 1
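In the causal branch, __init__ sets padding = (kernel_size - 1) * dilation; Conv1d applies it on both sides, so the slice above trims the right-hand overhang and no output frame depends on future samples. A minimal shape check with illustrative values:

    import torch

    kernel_size, dilation = 3, 2
    conv = torch.nn.Conv1d(1, 1, kernel_size, padding=(kernel_size - 1) * dilation, dilation=dilation)
    x = torch.randn(1, 1, 100)
    y = conv(x)[:, :, : x.size(-1)]  # same trim as the `residual.size(-1)` slice above
    print(y.shape)                   # torch.Size([1, 1, 100])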
@@ -82,6 +72,6 @@ class ResidualBlock(torch.nn.Module):
         s = self.conv1x1_skip(x)

         # for residual connection
-        x = (self.conv1x1_out(x) + residual) * (0.5**2)
+        x = (self.conv1x1_out(x) + residual) * (0.5 ** 2)

         return x, s

@@ -9,21 +9,21 @@ from scipy import signal as sig
 # https://github.com/kan-bayashi/ParallelWaveGAN/tree/master/parallel_wavegan
 class PQMF(torch.nn.Module):
     def __init__(self, N=4, taps=62, cutoff=0.15, beta=9.0):
-        super(PQMF, self).__init__()
+        super().__init__()

         self.N = N
         self.taps = taps
         self.cutoff = cutoff
         self.beta = beta

-        QMF = sig.firwin(taps + 1, cutoff, window=('kaiser', beta))
+        QMF = sig.firwin(taps + 1, cutoff, window=("kaiser", beta))
         H = np.zeros((N, len(QMF)))
         G = np.zeros((N, len(QMF)))
         for k in range(N):
-            constant_factor = (2 * k + 1) * (np.pi /
-                                             (2 * N)) * (np.arange(taps + 1) -
-                                                         ((taps - 1) / 2))  # TODO: (taps - 1) -> taps
-            phase = (-1)**k * np.pi / 4
+            constant_factor = (
+                (2 * k + 1) * (np.pi / (2 * N)) * (np.arange(taps + 1) - ((taps - 1) / 2))
+            )  # TODO: (taps - 1) -> taps
+            phase = (-1) ** k * np.pi / 4
             H[k] = 2 * QMF * np.cos(constant_factor + phase)

             G[k] = 2 * QMF * np.cos(constant_factor - phase)
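The loop above builds the usual cosine-modulated filter bank: with h the Kaiser-windowed low-pass prototype from sig.firwin and n = 0, ..., taps, the analysis and synthesis filters are

    H_k(n) = 2\, h(n)\, \cos\!\left( \frac{(2k+1)\pi}{2N} \left(n - \frac{taps - 1}{2}\right) + (-1)^k \frac{\pi}{4} \right)
    G_k(n) = 2\, h(n)\, \cos\!\left( \frac{(2k+1)\pi}{2N} \left(n - \frac{taps - 1}{2}\right) - (-1)^k \frac{\pi}{4} \right)

The (taps - 1) / 2 centering term is exactly what the inline TODO flags ((taps - 1) -> taps).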
@@ -49,8 +49,6 @@ class PQMF(torch.nn.Module):
         return F.conv1d(x, self.H, padding=self.taps // 2, stride=self.N)

     def synthesis(self, x):
-        x = F.conv_transpose1d(x,
-                               self.updown_filter * self.N,
-                               stride=self.N)
+        x = F.conv_transpose1d(x, self.updown_filter * self.N, stride=self.N)
         x = F.conv1d(x, self.G, padding=self.taps // 2)
         return x