mirror of https://github.com/coqui-ai/TTS.git
add SS files
parent 4b1065f6db
commit cf869e8922
@@ -0,0 +1,619 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import glob
import os
import sys
import time
import traceback
import numpy as np
from random import randrange

import torch
# DISTRIBUTED
from torch.nn.parallel import DistributedDataParallel as DDP_th
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from TTS.tts.datasets.preprocess import load_meta_data
from TTS.tts.datasets.TTSDataset import MyDataset
from TTS.tts.layers.losses import SpeedySpeechLoss
from TTS.tts.utils.generic_utils import check_config_tts, setup_model
from TTS.tts.utils.io import save_best_model, save_checkpoint
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.speakers import parse_speakers
from TTS.tts.utils.synthesis import synthesis
from TTS.tts.utils.text.symbols import make_symbols, phonemes, symbols
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.utils.console_logger import ConsoleLogger
from TTS.utils.distribute import init_distributed, reduce_tensor
from TTS.utils.generic_utils import (KeepAverage, count_parameters,
                                     create_experiment_folder, get_git_branch,
                                     remove_experiment_folder, set_init_dict)
from TTS.utils.io import copy_config_file, load_config
from TTS.utils.radam import RAdam
from TTS.utils.tensorboard_logger import TensorboardLogger
from TTS.utils.training import NoamLR, setup_torch_training_env

use_cuda, num_gpus = setup_torch_training_env(True, False)


def setup_loader(ap, r, is_val=False, verbose=False):
    if is_val and not c.run_eval:
        loader = None
    else:
        dataset = MyDataset(
            r,
            c.text_cleaner,
            compute_linear_spec=False,
            meta_data=meta_data_eval if is_val else meta_data_train,
            ap=ap,
            tp=c.characters if 'characters' in c.keys() else None,
            add_blank=c['add_blank'] if 'add_blank' in c.keys() else False,
            batch_group_size=0 if is_val else c.batch_group_size *
            c.batch_size,
            min_seq_len=c.min_seq_len,
            max_seq_len=c.max_seq_len,
            phoneme_cache_path=c.phoneme_cache_path,
            use_phonemes=c.use_phonemes,
            phoneme_language=c.phoneme_language,
            enable_eos_bos=c.enable_eos_bos_chars,
            use_noise_augment=not is_val,
            verbose=verbose,
            speaker_mapping=speaker_mapping if c.use_speaker_embedding and c.use_external_speaker_embedding_file else None)

        if c.use_phonemes and c.compute_input_seq_cache:
            # precompute phonemes to have a better estimate of sequence lengths.
            dataset.compute_input_seq(c.num_loader_workers)
        dataset.sort_items()

        sampler = DistributedSampler(dataset) if num_gpus > 1 else None
        loader = DataLoader(
            dataset,
            batch_size=c.eval_batch_size if is_val else c.batch_size,
            shuffle=False,
            collate_fn=dataset.collate_fn,
            drop_last=False,
            sampler=sampler,
            num_workers=c.num_val_loader_workers
            if is_val else c.num_loader_workers,
            pin_memory=False)
    return loader


def format_data(data):
    # setup input data
    text_input = data[0]
    text_lengths = data[1]
    speaker_names = data[2]
    mel_input = data[4].permute(0, 2, 1)  # B x D x T
    mel_lengths = data[5]
    item_idx = data[7]
    attn_mask = data[9]
    avg_text_length = torch.mean(text_lengths.float())
    avg_spec_length = torch.mean(mel_lengths.float())

    if c.use_speaker_embedding:
        if c.use_external_speaker_embedding_file:
            # return precomputed embedding vector
            speaker_c = data[8]
        else:
            # return speaker_id to be used by an embedding layer
            speaker_c = [
                speaker_mapping[speaker_name] for speaker_name in speaker_names
            ]
            speaker_c = torch.LongTensor(speaker_c)
    else:
        speaker_c = None
    # compute durations from attention mask
    durations = torch.zeros(attn_mask.shape[0], attn_mask.shape[2])
    for idx, am in enumerate(attn_mask):
        # compute raw durations
        c_idxs = am[:, :text_lengths[idx], :mel_lengths[idx]].max(1)[1]
        # c_idxs, counts = torch.unique_consecutive(c_idxs, return_counts=True)
        c_idxs, counts = torch.unique(c_idxs, return_counts=True)
        dur = torch.ones([text_lengths[idx]]).to(counts.dtype)
        dur[c_idxs] = counts
        # smooth the durations and set any 0 duration to 1
        # by cutting off from the largest duration indices.
        extra_frames = dur.sum() - mel_lengths[idx]
        largest_idxs = torch.argsort(-dur)[:extra_frames]
        dur[largest_idxs] -= 1
        assert dur.sum() == mel_lengths[idx], f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}"
        durations[idx, :text_lengths[idx]] = dur
    # dispatch data to GPU
    if use_cuda:
        text_input = text_input.cuda(non_blocking=True)
        text_lengths = text_lengths.cuda(non_blocking=True)
        mel_input = mel_input.cuda(non_blocking=True)
        mel_lengths = mel_lengths.cuda(non_blocking=True)
        if speaker_c is not None:
            speaker_c = speaker_c.cuda(non_blocking=True)
        attn_mask = attn_mask.cuda(non_blocking=True)
        durations = durations.cuda(non_blocking=True)
    return text_input, text_lengths, mel_input, mel_lengths, speaker_c,\
        avg_text_length, avg_spec_length, attn_mask, durations, item_idx
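
# NOTE: the duration extraction above is easiest to see on a toy example (an illustrative
# sketch, not part of the original script). If the 6 spectrogram frames of an item attend
# most strongly to the characters [0, 0, 1, 1, 1, 2], then `am[...].max(1)[1]` returns that
# index vector and `torch.unique(..., return_counts=True)` yields per-character durations
# [2, 3, 1]. Characters that were never attended keep the initial duration of 1, and the
# largest durations are trimmed until the total matches the spectrogram length, which the
# assert above enforces.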


def train(data_loader, model, criterion, optimizer, scheduler,
          ap, global_step, epoch):

    model.train()
    epoch_time = 0
    keep_avg = KeepAverage()
    if use_cuda:
        batch_n_iter = int(
            len(data_loader.dataset) / (c.batch_size * num_gpus))
    else:
        batch_n_iter = int(len(data_loader.dataset) / c.batch_size)
    end_time = time.time()
    c_logger.print_train_start()
    scaler = torch.cuda.amp.GradScaler() if c.mixed_precision else None
    for num_iter, data in enumerate(data_loader):
        start_time = time.time()

        # format data
        text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
            avg_text_length, avg_spec_length, attn_mask, dur_target, item_idx = format_data(data)

        loader_time = time.time() - end_time

        global_step += 1
        optimizer.zero_grad()

        # forward pass model
        with torch.cuda.amp.autocast(enabled=c.mixed_precision):
            decoder_output, dur_output, alignments = model.forward(
                text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)

            # compute loss
            loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)

        # backward pass with loss scaling
        if c.mixed_precision:
            scaler.scale(loss_dict['loss']).backward()
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss_dict['loss'].backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       c.grad_clip)
            optimizer.step()

        # setup lr
        if c.noam_schedule:
            scheduler.step()

        # current_lr
        current_lr = optimizer.param_groups[0]['lr']

        # compute alignment error (the lower the better)
        align_error = 1 - alignment_diagonal_score(alignments, binary=True)
        loss_dict['align_error'] = align_error

        step_time = time.time() - start_time
        epoch_time += step_time

        # aggregate losses from processes
        if num_gpus > 1:
            loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
            loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
            loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
            loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

        # detach loss values
        loss_dict_new = dict()
        for key, value in loss_dict.items():
            if isinstance(value, (int, float)):
                loss_dict_new[key] = value
            else:
                loss_dict_new[key] = value.item()
        loss_dict = loss_dict_new

        # update avg stats
        update_train_values = dict()
        for key, value in loss_dict.items():
            update_train_values['avg_' + key] = value
        update_train_values['avg_loader_time'] = loader_time
        update_train_values['avg_step_time'] = step_time
        keep_avg.update_values(update_train_values)

        # print training progress
        if global_step % c.print_step == 0:
            log_dict = {
                "avg_spec_length": [avg_spec_length, 1],  # value, precision
                "avg_text_length": [avg_text_length, 1],
                "step_time": [step_time, 4],
                "loader_time": [loader_time, 2],
                "current_lr": current_lr,
            }
            c_logger.print_train_step(batch_n_iter, num_iter, global_step,
                                      log_dict, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Plot Training Iter Stats
            # reduce TB load
            if global_step % c.tb_plot_step == 0:
                iter_stats = {
                    "lr": current_lr,
                    "grad_norm": grad_norm,
                    "step_time": step_time
                }
                iter_stats.update(loss_dict)
                tb_logger.tb_train_iter_stats(global_step, iter_stats)

            if global_step % c.save_step == 0:
                if c.checkpoint:
                    # save model
                    save_checkpoint(model, optimizer, global_step, epoch, 1, OUT_PATH,
                                    model_loss=loss_dict['loss'])

                # wait all kernels to be completed
                torch.cuda.synchronize()

                # Diagnostic visualizations
                idx = np.random.randint(mel_targets.shape[0])
                pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
                gt_spec = mel_targets[idx].data.cpu().numpy().T
                align_img = alignments[idx].data.cpu()

                figures = {
                    "prediction": plot_spectrogram(pred_spec, ap),
                    "ground_truth": plot_spectrogram(gt_spec, ap),
                    "alignment": plot_alignment(align_img),
                }

                tb_logger.tb_train_figures(global_step, figures)

                # Sample audio
                train_audio = ap.inv_melspectrogram(pred_spec.T)
                tb_logger.tb_train_audios(global_step,
                                          {'TrainAudio': train_audio},
                                          c.audio["sample_rate"])
        end_time = time.time()

    # print epoch stats
    c_logger.print_train_epoch_end(global_step, epoch, epoch_time, keep_avg)

    # Plot Epoch Stats
    if args.rank == 0:
        epoch_stats = {"epoch_time": epoch_time}
        epoch_stats.update(keep_avg.avg_values)
        tb_logger.tb_train_epoch_stats(global_step, epoch_stats)
        if c.tb_model_param_stats:
            tb_logger.tb_model_weights(model, global_step)
    return keep_avg.avg_values, global_step


@torch.no_grad()
def evaluate(data_loader, model, criterion, ap, global_step, epoch):
    model.eval()
    epoch_time = 0
    keep_avg = KeepAverage()
    c_logger.print_eval_start()
    if data_loader is not None:
        for num_iter, data in enumerate(data_loader):
            start_time = time.time()

            # format data
            text_input, text_lengths, mel_targets, mel_lengths, speaker_c,\
                avg_text_length, avg_spec_length, attn_mask, dur_target, item_idx = format_data(data)

            # forward pass model
            with torch.cuda.amp.autocast(enabled=c.mixed_precision):
                decoder_output, dur_output, alignments = model.forward(
                    text_input, text_lengths, mel_lengths, dur_target, g=speaker_c)

                # compute loss
                loss_dict = criterion(decoder_output, mel_targets, mel_lengths, dur_output, torch.log(1 + dur_target), text_lengths)

            # step time
            step_time = time.time() - start_time
            epoch_time += step_time

            # compute alignment score
            align_error = 1 - alignment_diagonal_score(alignments)
            loss_dict['align_error'] = align_error

            # aggregate losses from processes
            if num_gpus > 1:
                loss_dict['loss_l1'] = reduce_tensor(loss_dict['loss_l1'].data, num_gpus)
                loss_dict['loss_ssim'] = reduce_tensor(loss_dict['loss_ssim'].data, num_gpus)
                loss_dict['loss_dur'] = reduce_tensor(loss_dict['loss_dur'].data, num_gpus)
                loss_dict['loss'] = reduce_tensor(loss_dict['loss'].data, num_gpus)

            # detach loss values
            loss_dict_new = dict()
            for key, value in loss_dict.items():
                if isinstance(value, (int, float)):
                    loss_dict_new[key] = value
                else:
                    loss_dict_new[key] = value.item()
            loss_dict = loss_dict_new

            # update avg stats
            update_train_values = dict()
            for key, value in loss_dict.items():
                update_train_values['avg_' + key] = value
            keep_avg.update_values(update_train_values)

            if c.print_eval:
                c_logger.print_eval_step(num_iter, loss_dict, keep_avg.avg_values)

        if args.rank == 0:
            # Diagnostic visualizations
            idx = np.random.randint(mel_targets.shape[0])
            pred_spec = decoder_output[idx].detach().data.cpu().numpy().T
            gt_spec = mel_targets[idx].data.cpu().numpy().T
            align_img = alignments[idx].data.cpu()

            eval_figures = {
                "prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
                "ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
                "alignment": plot_alignment(align_img, output_fig=False)
            }

            # Sample audio
            eval_audio = ap.inv_melspectrogram(pred_spec.T)
            tb_logger.tb_eval_audios(global_step, {"ValAudio": eval_audio},
                                     c.audio["sample_rate"])

            # Plot Validation Stats
            tb_logger.tb_eval_stats(global_step, keep_avg.avg_values)
            tb_logger.tb_eval_figures(global_step, eval_figures)

    if args.rank == 0 and epoch >= c.test_delay_epochs:
        if c.test_sentences_file is None:
            test_sentences = [
                "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
                "Be a voice, not an echo.",
                "I'm sorry Dave. I'm afraid I can't do that.",
                "This cake is great. It's so delicious and moist.",
                "Prior to November 22, 1963."
            ]
        else:
            with open(c.test_sentences_file, "r") as f:
                test_sentences = [s.strip() for s in f.readlines()]

        # test sentences
        test_audios = {}
        test_figures = {}
        print(" | > Synthesizing test sentences")
        if c.use_speaker_embedding:
            if c.use_external_speaker_embedding_file:
                speaker_embedding = speaker_mapping[list(speaker_mapping.keys())[randrange(len(speaker_mapping)-1)]]['embedding']
                speaker_id = None
            else:
                speaker_id = 0
                speaker_embedding = None
        else:
            speaker_id = None
            speaker_embedding = None

        style_wav = c.get("style_wav_for_test")
        for idx, test_sentence in enumerate(test_sentences):
            try:
                wav, alignment, _, postnet_output, _, _ = synthesis(
                    model,
                    test_sentence,
                    c,
                    use_cuda,
                    ap,
                    speaker_id=speaker_id,
                    speaker_embedding=speaker_embedding,
                    style_wav=style_wav,
                    truncated=False,
                    enable_eos_bos_chars=c.enable_eos_bos_chars,  #pylint: disable=unused-argument
                    use_griffin_lim=True,
                    do_trim_silence=False)

                file_path = os.path.join(AUDIO_PATH, str(global_step))
                os.makedirs(file_path, exist_ok=True)
                file_path = os.path.join(file_path,
                                         "TestSentence_{}.wav".format(idx))
                ap.save_wav(wav, file_path)
                test_audios['{}-audio'.format(idx)] = wav
                test_figures['{}-prediction'.format(idx)] = plot_spectrogram(
                    postnet_output, ap)
                test_figures['{}-alignment'.format(idx)] = plot_alignment(
                    alignment)
            except:  #pylint: disable=bare-except
                print(" !! Error creating Test Sentence -", idx)
                traceback.print_exc()
        tb_logger.tb_test_audios(global_step, test_audios,
                                 c.audio['sample_rate'])
        tb_logger.tb_test_figures(global_step, test_figures)
    return keep_avg.avg_values


# FIXME: move args definition/parsing inside of main?
def main(args):  # pylint: disable=redefined-outer-name
    # pylint: disable=global-variable-undefined
    global meta_data_train, meta_data_eval, symbols, phonemes, speaker_mapping
    # Audio processor
    ap = AudioProcessor(**c.audio)
    if 'characters' in c.keys():
        symbols, phonemes = make_symbols(**c.characters)

    # DISTRIBUTED
    if num_gpus > 1:
        init_distributed(args.rank, num_gpus, args.group_id,
                         c.distributed["backend"], c.distributed["url"])
    num_chars = len(phonemes) if c.use_phonemes else len(symbols)

    # load data instances
    meta_data_train, meta_data_eval = load_meta_data(c.datasets, eval_split=True)

    # set the portion of the data used for training if set in config.json
    if 'train_portion' in c.keys():
        meta_data_train = meta_data_train[:int(len(meta_data_train) * c.train_portion)]
    if 'eval_portion' in c.keys():
        meta_data_eval = meta_data_eval[:int(len(meta_data_eval) * c.eval_portion)]

    # parse speakers
    num_speakers, speaker_embedding_dim, speaker_mapping = parse_speakers(c, args, meta_data_train, OUT_PATH)

    # setup model
    model = setup_model(num_chars, num_speakers, c, speaker_embedding_dim=speaker_embedding_dim)
    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0, betas=(0.9, 0.98), eps=1e-9)
    criterion = SpeedySpeechLoss(c)

    if args.restore_path:
        checkpoint = torch.load(args.restore_path, map_location='cpu')
        try:
            # TODO: fix optimizer init, model.cuda() needs to be called before
            # optimizer restore
            optimizer.load_state_dict(checkpoint['optimizer'])
            if c.reinit_layers:
                raise RuntimeError
            model.load_state_dict(checkpoint['model'])
        except:  #pylint: disable=bare-except
            print(" > Partial model initialization.")
            model_dict = model.state_dict()
            model_dict = set_init_dict(model_dict, checkpoint['model'], c)
            model.load_state_dict(model_dict)
            del model_dict

        for group in optimizer.param_groups:
            group['initial_lr'] = c.lr
        print(" > Model restored from step %d" % checkpoint['step'],
              flush=True)
        args.restore_step = checkpoint['step']
    else:
        args.restore_step = 0

    if use_cuda:
        model.cuda()
        criterion.cuda()

    # DISTRIBUTED
    if num_gpus > 1:
        model = DDP_th(model, device_ids=[args.rank])

    if c.noam_schedule:
        scheduler = NoamLR(optimizer,
                           warmup_steps=c.warmup_steps,
                           last_epoch=args.restore_step - 1)
    else:
        scheduler = None

    num_params = count_parameters(model)
    print("\n > Model has {} parameters".format(num_params), flush=True)

    if 'best_loss' not in locals():
        best_loss = float('inf')

    # define dataloaders
    train_loader = setup_loader(ap, 1, is_val=False, verbose=True)
    eval_loader = setup_loader(ap, 1, is_val=True, verbose=True)

    global_step = args.restore_step
    for epoch in range(0, c.epochs):
        c_logger.print_epoch_start(epoch, c.epochs)
        train_avg_loss_dict, global_step = train(train_loader, model, criterion, optimizer,
                                                 scheduler, ap, global_step,
                                                 epoch)
        eval_avg_loss_dict = evaluate(eval_loader, model, criterion, ap, global_step, epoch)
        c_logger.print_epoch_end(epoch, eval_avg_loss_dict)
        target_loss = train_avg_loss_dict['avg_loss']
        if c.run_eval:
            target_loss = eval_avg_loss_dict['avg_loss']
        best_loss = save_best_model(target_loss, best_loss, model, optimizer, global_step, epoch, c.r,
                                    OUT_PATH)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--continue_path',
        type=str,
        help='Training output folder to continue training. Use to continue a training. If it is used, "config_path" is ignored.',
        default='',
        required='--config_path' not in sys.argv)
    parser.add_argument(
        '--restore_path',
        type=str,
        help='Model file to be restored. Use to finetune a model.',
        default='')
    parser.add_argument(
        '--config_path',
        type=str,
        help='Path to config file for training.',
        required='--continue_path' not in sys.argv
    )
    parser.add_argument('--debug',
                        type=bool,
                        default=False,
                        help='Do not verify commit integrity to run training.')

    # DISTRIBUTED
    parser.add_argument(
        '--rank',
        type=int,
        default=0,
        help='DISTRIBUTED: process rank for distributed training.')
    parser.add_argument('--group_id',
                        type=str,
                        default="",
                        help='DISTRIBUTED: process group id.')
    args = parser.parse_args()

    if args.continue_path != '':
        args.output_path = args.continue_path
        args.config_path = os.path.join(args.continue_path, 'config.json')
        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # * means all if need specific format then *.csv
        latest_model_file = max(list_of_files, key=os.path.getctime)
        args.restore_path = latest_model_file
        print(f" > Training continues for {args.restore_path}")

    # setup output paths and read configs
    c = load_config(args.config_path)
    # check_config(c)
    check_config_tts(c)
    _ = os.path.dirname(os.path.realpath(__file__))

    if c.mixed_precision:
        print(" > Mixed precision enabled.")

    OUT_PATH = args.continue_path
    if args.continue_path == '':
        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)

    AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')

    c_logger = ConsoleLogger()

    if args.rank == 0:
        os.makedirs(AUDIO_PATH, exist_ok=True)
        new_fields = {}
        if args.restore_path:
            new_fields["restore_path"] = args.restore_path
        new_fields["github_branch"] = get_git_branch()
        copy_config_file(args.config_path,
                         os.path.join(OUT_PATH, 'config.json'), new_fields)
        os.chmod(AUDIO_PATH, 0o775)
        os.chmod(OUT_PATH, 0o775)

        LOG_DIR = OUT_PATH
        tb_logger = TensorboardLogger(LOG_DIR, model_name='TTS')

        # write model desc to tensorboard
        tb_logger.tb_add_text('model-description', c['run_description'], 0)

    try:
        main(args)
    except KeyboardInterrupt:
        remove_experiment_folder(OUT_PATH)
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)  # pylint: disable=protected-access
    except Exception:  # pylint: disable=broad-except
        remove_experiment_folder(OUT_PATH)
        traceback.print_exc()
        sys.exit(1)
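
# Example invocation (a sketch based on the argparse flags defined above; the script's
# file name and path are assumed, since the diff does not show them):
#
#   python TTS/bin/train_speedy_speech.py --config_path config.json
#   python TTS/bin/train_speedy_speech.py --continue_path /path/to/a/previous/run/
#
# One of --config_path or --continue_path must be given, since each flag is only marked
# as required when the other is absent from sys.argv.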

@@ -0,0 +1,149 @@
{
    "model": "speedy_speech",
    "run_name": "speedy-speech-ljspeech",
    "run_description": "speedy-speech model for LJSpeech dataset.",

    // AUDIO PARAMETERS
    "audio":{
        // stft parameters
        "fft_size": 1024,         // number of stft frequency levels. Size of the linear spectrogram frame.
        "win_length": 1024,       // stft window length in ms.
        "hop_length": 256,        // stft window hop-length in ms.
        "frame_length_ms": null,  // stft window length in ms. If null, 'win_length' is used.
        "frame_shift_ms": null,   // stft window hop-length in ms. If null, 'hop_length' is used.

        // Audio processing parameters
        "sample_rate": 22050,   // DATASET-RELATED: wav sample-rate.
        "preemphasis": 0.0,     // pre-emphasis to reduce spec noise and make it more structured. If 0.0, no pre-emphasis.
        "ref_level_db": 20,     // reference level db, theoretically 20db is the sound of air.

        // Silence trimming
        "do_trim_silence": true,  // enable trimming of silence of audio as you load it. LJSpeech (true), TWEB (false), Nancy (true)
        "trim_db": 60,            // threshold for trimming silence. Set this according to your dataset.

        // Griffin-Lim
        "power": 1.5,             // value to sharpen wav signals after GL algorithm.
        "griffin_lim_iters": 60,  // #griffin-lim iterations. 30-60 is a good range. Larger the value, slower the generation.

        // MelSpectrogram parameters
        "num_mels": 80,       // size of the mel spec frame.
        "mel_fmin": 50.0,     // minimum freq level for mel-spec. ~50 for male and ~95 for female voices. Tune for dataset!!
        "mel_fmax": 7600.0,   // maximum freq level for mel-spec. Tune for dataset!!
        "spec_gain": 1,

        // Normalization parameters
        "signal_norm": true,    // normalize spec values. Mean-Var normalization if 'stats_path' is defined otherwise range normalization defined by the other params.
        "min_level_db": -100,   // lower bound for normalization
        "symmetric_norm": true, // move normalization to range [-1, 1]
        "max_norm": 4.0,        // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
        "clip_norm": true,      // clip normalized values into the range.
        "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy"  // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
    },

    // VOCABULARY PARAMETERS
    // if custom character set is not defined,
    // default set in symbols.py is used
    // "characters":{
    //     "pad": "_",
    //     "eos": "&",
    //     "bos": "*",
    //     "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZÇÃÀÁÂÊÉÍÓÔÕÚÛabcdefghijklmnopqrstuvwxyzçãàáâêéíóôõúû!(),-.:;? ",
    //     "punctuations":"!'(),-.:;? ",
    //     "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ'̃' "
    // },

    "add_blank": false,  // if true add a new token after each token of the sentence. This increases the size of the input sequence, but has considerably improved the prosody of the GlowTTS model.

    // DISTRIBUTED TRAINING
    "distributed":{
        "backend": "nccl",
        "url": "tcp:\/\/localhost:54321"
    },

    "reinit_layers": [],  // give a list of layer names to restore from the given checkpoint. If not defined, it reloads all heuristically matching layers.

    // MODEL PARAMETERS
    "positional_encoding": true,
    "encoder_type": "residual_conv_bn",
    "encoder_params":{
        "kernel_size": 4,
        "dilations": [1, 2, 4, 1, 2, 4, 1, 2, 4, 1, 2, 4, 1],
        "num_conv_blocks": 2,
        "num_res_blocks": 13
    },
    "decoder_residual_conv_bn_params":{
        "kernel_size": 4,
        "dilations": [1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1, 2, 4, 8, 1],
        "num_conv_blocks": 2,
        "num_res_blocks": 17
    },

    // TRAINING
    "batch_size": 64,      // Batch size for training. Lower values than 32 might cause hard to learn attention. It is overwritten by 'gradual_training'.
    "eval_batch_size": 32,
    "r": 1,                // Number of decoder frames to predict per iteration. Set the initial values if gradual training is enabled.
    "loss_masking": true,  // enable / disable loss masking against the sequence padding.

    // LOSS PARAMETERS
    "ssim_alpha": 1,
    "l1_alpha": 1,
    "huber_alpha": 1,

    // VALIDATION
    "run_eval": true,
    "test_delay_epochs": -1,      // Until attention is aligned, testing only wastes computation time.
    "test_sentences_file": null,  // set a file to load sentences to be used for testing. If it is null then we use default english sentences.

    // OPTIMIZER
    "noam_schedule": true,  // use noam warmup and lr schedule.
    "grad_clip": 1.0,       // upper limit for gradients for clipping.
    "epochs": 10000,        // total number of epochs to train.
    "lr": 0.002,            // Initial learning rate. If Noam decay is active, maximum learning rate.
    "warmup_steps": 4000,   // Noam decay steps to increase the learning rate from 0 to "lr"

    // TENSORBOARD and LOGGING
    "print_step": 25,      // Number of steps to log training on console.
    "tb_plot_step": 100,   // Number of steps to plot TB training figures.
    "print_eval": false,   // If True, it prints intermediate loss values in evaluation.
    "save_step": 5000,     // Number of training steps expected to save training stats and checkpoints.
    "checkpoint": true,    // If true, it saves checkpoints per "save_step"
    "tb_model_param_stats": false,  // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
    "mixed_precision": false,

    // DATA LOADING
    "text_cleaner": "english_cleaners",
    "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
    "num_loader_workers": 8,        // number of training data loader processes. Don't set it too big. 4-8 are good values.
    "num_val_loader_workers": 8,    // number of evaluation data loader processes.
    "batch_group_size": 0,          // Number of batches to shuffle after bucketing.
    "min_seq_len": 2,               // DATASET-RELATED: minimum text length to use in training
    "max_seq_len": 300,             // DATASET-RELATED: maximum text length
    "compute_f0": false,            // compute f0 values in data-loader
    "compute_input_seq_cache": false,  // if true, text sequences are computed before starting training. If phonemes are enabled, they are also computed at this stage.

    // PATHS
    "output_path": "/home/erogol/Models/ljspeech/",

    // PHONEMES
    "phoneme_cache_path": "/home/erogol/Models/ljspeech_phonemes/",  // phoneme computation is slow, therefore, it caches results in the given folder.
    "use_phonemes": true,         // use phonemes instead of raw characters. It is suggested for better pronunciation.
    "phoneme_language": "en-us",  // depending on your target language, pick one from https://github.com/bootphon/phonemizer#languages

    // MULTI-SPEAKER and GST
    "use_speaker_embedding": false,                // use speaker embedding to enable multi-speaker learning.
    "use_external_speaker_embedding_file": false,  // if true, forces the model to use external embedding per sample instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558
    "external_speaker_embedding_file": "/home/erogol/Data/libritts/speakers.json",  // if not null and use_external_speaker_embedding_file is true, it is used to load a specific embedding file and thus uses these embeddings instead of nn.embeddings, that is, it supports external embeddings such as those used at: https://arxiv.org/abs/1806.04558

    // DATASETS
    "datasets":  // List of datasets. They are all merged and they get different speaker ids.
        [
            {
                "name": "ljspeech",
                "path": "/home/erogol/Data/LJSpeech-1.1/",
                "meta_file_train": "metadata.csv",
                "meta_file_val": null,
                "meta_file_attn_mask": "/home/erogol/Data/LJSpeech-1.1/metadata_attn_mask.txt"  // created by bin/compute_attention_masks.py
            }
        ]
}
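
// NOTE (illustrative, not part of the config file): the training script above reads this
// file with TTS.utils.io.load_config() and exposes every key with attribute access, e.g.
// c.batch_size -> 64, c.lr -> 0.002, c.audio["sample_rate"] -> 22050, as its usage of `c`
// throughout the script shows.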

@@ -0,0 +1,41 @@
import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.glow_tts.transformer import Transformer
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock, ConvBNBlock


class Decoder(nn.Module):
    """Decodes the expanded phoneme encoding into spectrograms
    Shapes:
        - input: (B, C, T)
    """
    # pylint: disable=dangerous-default-value
    def __init__(
            self,
            out_channels,
            hidden_channels,
            residual_conv_bn_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            }):
        super().__init__()

        self.decoder = ResidualConvBNBlock(hidden_channels,
                                           **residual_conv_bn_params)

        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
        self.post_net = nn.Sequential(
            ConvBNBlock(hidden_channels, 4, 1, num_conv_blocks=2),
            nn.Conv1d(hidden_channels, out_channels, 1),
        )

    def forward(self, x, x_mask, g=None):
        o = self.decoder(x, x_mask)
        o = self.post_conv(o) + x
        return self.post_net(o)

@@ -0,0 +1,27 @@
from torch import nn

from TTS.tts.layers.generic.res_conv_bn import ConvBN


class DurationPredictor(nn.Module):
    """Predicts phoneme log durations based on the encoder outputs"""
    def __init__(self, hidden_channels):
        super().__init__()

        self.layers = nn.ModuleList([
            ConvBN(hidden_channels, 4, 1),
            ConvBN(hidden_channels, 3, 1),
            ConvBN(hidden_channels, 1, 1),
            nn.Conv1d(hidden_channels, 1, 1)
        ])

    def forward(self, x, x_mask):
        """Outputs interpreted as log(durations)
        To get actual durations, do exp transformation
        :param x:
        :return:
        """
        o = x
        for layer in self.layers:
            o = layer(o) * x_mask
        return o
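
# A minimal sketch (illustrative, not part of the file above) of turning the predictor's
# log-durations back into integer frame counts; it mirrors SpeedySpeech.format_durations()
# defined later in this commit.
import torch

def durations_from_log(o_dr_log, x_mask, length_scale=1.0):
    o_dr = (torch.exp(o_dr_log) - 1) * x_mask * length_scale  # invert the log(1 + dur) training target
    o_dr[o_dr < 1] = 1.0                                      # every character lasts at least one frame
    return torch.round(o_dr)                                  # integer number of decoder frames per character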

@@ -0,0 +1,153 @@
import math
import torch
from torch import nn
from torch.nn import functional as F

from TTS.tts.layers.glow_tts.transformer import Transformer
from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for non-recurrent neural networks.
    Implementation based on "Attention Is All You Need"

    Args:
        dropout (float): dropout parameter
        dim (int): embedding size
    """
    def __init__(self, dim, dropout=0.0, max_len=5000):
        if dim % 2 != 0:
            raise ValueError("Cannot use sin/cos positional encoding with "
                             "odd dim (got dim={:d})".format(dim))
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp((torch.arange(0, dim, 2, dtype=torch.float) *
                              -(math.log(10000.0) / dim)))
        pe[:, 0::2] = torch.sin(position.float() * div_term)
        pe[:, 1::2] = torch.cos(position.float() * div_term)
        pe = pe.unsqueeze(0).transpose(1, 2)
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pe', pe)
        if dropout > 0:
            self.dropout = nn.Dropout(p=dropout)
        self.dim = dim

    def forward(self, x, step=None):
        """Embed inputs.

        Args:
            x (FloatTensor): Sequence of word vectors
                ``(seq_len, batch_size, self.dim)``
            step (int or NoneType): If stepwise (``seq_len = 1``), use
                the encoding for this position.

        Shapes:
            x: B x C x T
        """
        x = x * math.sqrt(self.dim)
        if step is None:
            if self.pe.size(2) < x.size(2):
                raise RuntimeError(
                    f"Sequence is {x.size(2)} but PositionalEncoding is"
                    f" limited to {self.pe.size(2)}. See max_len argument."
                )
            x = x + self.pe[:, :, :x.size(2)]
        else:
            x = x + self.pe[:, :, step]
        if hasattr(self, 'dropout'):
            x = self.dropout(x)
        return x


class Encoder(nn.Module):
    # pylint: disable=dangerous-default-value
    def __init__(
            self,
            hidden_channels,
            out_channels,
            encoder_type='residual_conv_bn',
            encoder_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            },
            c_in_channels=0):
        """Speedy-Speech encoder using Transformers or Residual BN Convs internally.

        Args:
            hidden_channels (int): encoder's embedding size.
            out_channels (int): number of output channels.
            encoder_type (str): encoder layer types. 'transformer' or 'residual_conv_bn'. Default 'residual_conv_bn'.
            encoder_params (dict): model parameters for specified encoder type.
            c_in_channels (int): number of channels for conditional input.

        Note:
            Default encoder_params...

            for 'transformer'
                encoder_params={
                    'hidden_channels_ffn': 768,
                    'num_heads': 2,
                    "kernel_size": 3,
                    "dropout_p": 0.1,
                    "num_layers": 6,
                    "rel_attn_window_size": 4,
                    "input_length": None
                },

            for 'residual_conv_bn'
                encoder_params = {
                    "kernel_size": 4,
                    "dilations": 4 * [1, 2, 4] + [1],
                    "num_conv_blocks": 2,
                    "num_res_blocks": 13
                }

        Shapes:
            - input: (B, C, T)
        """
        super().__init__()
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.encoder_type = encoder_type
        self.c_in_channels = c_in_channels

        # init encoder
        if encoder_type.lower() == "transformer":
            # optional convolutional prenet
            self.pre = ConvLayerNorm(hidden_channels,
                                     hidden_channels,
                                     hidden_channels,
                                     kernel_size=5,
                                     num_layers=3,
                                     dropout_p=0.5)
            # text encoder
            self.encoder = Transformer(hidden_channels, **encoder_params)
        elif encoder_type.lower() == 'residual_conv_bn':
            self.pre = nn.Sequential(
                nn.Conv1d(hidden_channels, hidden_channels, 1), nn.ReLU())
            self.encoder = ResidualConvBNBlock(hidden_channels,
                                               **encoder_params)
        else:
            raise NotImplementedError(' [!] encoder type not implemented.')

        # final projection layers
        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
        self.post_bn = nn.BatchNorm1d(hidden_channels)
        self.post_conv2 = nn.Conv1d(hidden_channels, out_channels, 1)

    def forward(self, x, x_mask, g=None):
        if self.encoder_type == 'transformer':
            o = self.pre(x, x_mask)
        else:
            o = self.pre(x) * x_mask
        o = self.encoder(o, x_mask)
        o = self.post_conv(o + x)
        o = self.post_bn(o)
        o = F.relu(o)
        o = self.post_conv2(o)
        # [B, C, T]
        return o * x_mask
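
# Shape-check sketch for the channel-first PositionalEncoding above (illustrative only;
# it follows the B x C x T convention documented in its forward()).
import torch

pe = PositionalEncoding(dim=128)
x = torch.zeros(2, 128, 50)    # B x C x T
y = pe(x)                      # adds pe[:, :, :50]; raises RuntimeError if T exceeds max_len (default 5000)
assert y.shape == (2, 128, 50)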

@@ -0,0 +1,123 @@
import torch
from torch import nn

from TTS.tts.layers.speedy_speech.decoder import Decoder
from TTS.tts.layers.speedy_speech.duration_predictor import DurationPredictor
from TTS.tts.layers.speedy_speech.encoder import Encoder, PositionalEncoding
from TTS.tts.utils.generic_utils import sequence_mask
from TTS.tts.layers.glow_tts.monotonic_align import generate_path


class SpeedySpeech(nn.Module):
    def __init__(
            self,
            num_chars,
            out_channels,
            hidden_channels,
            positional_encoding=True,
            length_scale=1,
            encoder_type='residual_conv_bn',
            encoder_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 13
            },
            decoder_residual_conv_bn_params={
                "kernel_size": 4,
                "dilations": 4 * [1, 2, 4, 8] + [1],
                "num_conv_blocks": 2,
                "num_res_blocks": 17
            },
            c_in_channels=0):
        super().__init__()
        self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
        self.emb = nn.Embedding(num_chars, hidden_channels)
        self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
                               encoder_params, c_in_channels)
        if positional_encoding:
            self.pos_encoder = PositionalEncoding(hidden_channels)
        self.decoder = Decoder(out_channels, hidden_channels,
                               decoder_residual_conv_bn_params)
        self.duration_predictor = DurationPredictor(hidden_channels)

    def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
        attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
        o_en_ex = torch.matmul(
            attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
        return o_en_ex, attn

    def format_durations(self, o_dr_log, x_mask):
        o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
        o_dr[o_dr < 1] = 1.0
        o_dr = torch.round(o_dr)
        return o_dr

    def forward(self, x, x_lengths, y_lengths, dr, g=None):
        """Forward pass with ground-truth durations.

        Shapes:
            x: [B, T_seq] character ids
            x_lengths: [B] character sequence lengths
            y_lengths: [B] target spectrogram lengths
            dr: [B, T_seq] ground-truth durations (frames per character)
            g: optional speaker conditioning
        """
        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).to(x.dtype)
        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
                                 1).to(x_mask.dtype)

        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # duration predictor pass
        o_dr_log = self.duration_predictor(o_en.detach(), x_mask)

        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)

        # positional encoding
        if hasattr(self, 'pos_encoder'):
            o_en_ex = self.pos_encoder(o_en_ex)

        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask)

        return o_de, o_dr_log.squeeze(1), attn.transpose(1, 2)

    def inference(self, x, x_lengths, g=None):
        # pad input to prevent dropping the last word
        x = torch.nn.functional.pad(x, pad=(0, 5), mode='constant', value=0)

        # [B, T, C]
        x_emb = self.emb(x)
        # [B, C, T]
        x_emb = torch.transpose(x_emb, 1, -1)

        # compute sequence masks
        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
                                 1).to(x.dtype)
        # encoder pass
        o_en = self.encoder(x_emb, x_mask)

        # duration predictor pass
        o_dr_log = self.duration_predictor(o_en.detach(), x_mask)
        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)

        # output mask
        y_mask = torch.unsqueeze(sequence_mask(o_dr.sum(1), None), 1).to(x_mask.dtype)

        # expand o_en with durations
        o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)

        # positional encoding
        if hasattr(self, 'pos_encoder'):
            o_en_ex = self.pos_encoder(o_en_ex)

        # decoder pass
        o_de = self.decoder(o_en_ex, y_mask)

        return o_de, attn.transpose(1, 2)
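
# End-to-end shape check for the model above: an illustrative sketch, not part of the
# commit. The import path of SpeedySpeech is assumed (the diff does not show file names);
# everything else follows the signatures defined in this commit.
import torch
from TTS.tts.models.speedy_speech import SpeedySpeech  # assumed module path

model = SpeedySpeech(num_chars=42, out_channels=80, hidden_channels=128)
x = torch.randint(0, 42, (2, 21))    # [B, T_text] character ids
x_lengths = torch.tensor([21, 17])
dr = torch.ones(2, 21)               # one frame per character
dr[1, 17:] = 0                       # zero durations on the padded positions
y_lengths = dr.sum(1).long()         # decoder length must equal the total duration per item
o_de, o_dr_log, attn = model(x, x_lengths, y_lengths, dr)
print(o_de.shape)                    # expected: [2, 80, 21] (B, num_mels, T_decoder)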