mirror of https://github.com/coqui-ai/TTS.git
Mask inputs by length to reduce the effetc on attention module
parent
4e0ab65bbf
commit
966d567540
|
@ -3,6 +3,7 @@ import os
|
|||
import numpy as np
|
||||
import collections
|
||||
import librosa
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from TTS.utils.text import text_to_sequence
|
||||
|
@ -45,6 +46,9 @@ class LJSpeechDataset(Dataset):
|
|||
sample = {'text': text, 'wav': wav}
|
||||
return sample
|
||||
|
||||
def get_dummy_data(self):
|
||||
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
|
||||
# Puts each data field into a tensor with outer dimension batch size
|
||||
|
@ -73,7 +77,7 @@ class LJSpeechDataset(Dataset):
|
|||
magnitude = magnitude.transpose(0, 2, 1)
|
||||
mel = mel.transpose(0, 2, 1)
|
||||
|
||||
return text, magnitude, mel
|
||||
return text, text_lenghts, magnitude, mel
|
||||
|
||||
raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
|
||||
found {}"
|
||||
|
|
35
train.py
35
train.py
|
@ -12,6 +12,7 @@ import numpy as np
|
|||
|
||||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from torch import onnx
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
|
@ -52,6 +53,7 @@ def main(args):
|
|||
sys.exit(1)
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
# Setup the dataset
|
||||
dataset = LJSpeechDataset(os.path.join(c.data_path, 'metadata.csv'),
|
||||
os.path.join(c.data_path, 'wavs'),
|
||||
c.r,
|
||||
|
@ -67,11 +69,25 @@ def main(args):
|
|||
c.power
|
||||
)
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||
shuffle=True, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=c.num_loader_workers)
|
||||
|
||||
# setup the model
|
||||
model = Tacotron(c.embedding_size,
|
||||
c.hidden_size,
|
||||
c.num_mels,
|
||||
c.num_freq,
|
||||
c.r)
|
||||
|
||||
# plot model on tensorboard
|
||||
dummy_input = dataset.get_dummy_data()
|
||||
|
||||
## TODO: onnx does not support RNN fully yet
|
||||
# model_proto_path = os.path.join(OUT_PATH, "model.proto")
|
||||
# onnx.export(model, dummy_input, model_proto_path, verbose=True)
|
||||
# tb.add_graph_onnx(model_proto_path)
|
||||
|
||||
if use_cuda:
|
||||
model = nn.DataParallel(model.cuda())
|
||||
|
||||
|
@ -105,9 +121,6 @@ def main(args):
|
|||
epoch_time = 0
|
||||
for epoch in range(c.epochs):
|
||||
|
||||
dataloader = DataLoader(dataset, batch_size=c.batch_size,
|
||||
shuffle=True, collate_fn=dataset.collate_fn,
|
||||
drop_last=True, num_workers=c.num_loader_workers)
|
||||
print("\n | > Epoch {}/{}".format(epoch, c.epochs))
|
||||
progbar = Progbar(len(dataset) / c.batch_size)
|
||||
|
||||
|
@ -115,8 +128,9 @@ def main(args):
|
|||
start_time = time.time()
|
||||
|
||||
text_input = data[0]
|
||||
magnitude_input = data[1]
|
||||
mel_input = data[2]
|
||||
text_lengths = data[1]
|
||||
magnitude_input = data[2]
|
||||
mel_input = data[3]
|
||||
|
||||
current_step = i + args.restore_step + epoch * len(dataloader) + 1
|
||||
|
||||
|
@ -137,6 +151,8 @@ def main(args):
|
|||
if use_cuda:
|
||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
||||
torch.cuda.LongTensor)).cuda()
|
||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
||||
torch.cuda.LongTensor)).cuda()
|
||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
||||
torch.cuda.FloatTensor)).cuda()
|
||||
mel_spec_var = Variable(torch.from_numpy(mel_input).type(
|
||||
|
@ -147,6 +163,8 @@ def main(args):
|
|||
else:
|
||||
text_input_var = Variable(torch.from_numpy(text_input).type(
|
||||
torch.LongTensor),)
|
||||
text_lengths_var = Variable(torch.from_numpy(test_lengths).type(
|
||||
torch.LongTensor))
|
||||
mel_input_var = Variable(torch.from_numpy(mel_input).type(
|
||||
torch.FloatTensor))
|
||||
mel_spec_var = Variable(torch.from_numpy(
|
||||
|
@ -155,7 +173,7 @@ def main(args):
|
|||
magnitude_input).type(torch.FloatTensor))
|
||||
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input_var, mel_input_var)
|
||||
model.forward(text_input_var, mel_input_var, input_lengths=input_lengths_var)
|
||||
|
||||
mel_loss = criterion(mel_output, mel_spec_var)
|
||||
#linear_loss = torch.abs(linear_output - linear_spec_var)
|
||||
|
@ -169,7 +187,7 @@ def main(args):
|
|||
# loss = loss.cuda()
|
||||
|
||||
loss.backward()
|
||||
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.)
|
||||
grad_norm = nn.utils.clip_grad_norm(model.parameters(), 1.) ## TODO: maybe no need
|
||||
optimizer.step()
|
||||
|
||||
step_time = time.time() - start_time
|
||||
|
@ -180,11 +198,12 @@ def main(args):
|
|||
('mel_loss', mel_loss.data[0]),
|
||||
('grad_norm', grad_norm)])
|
||||
|
||||
|
||||
# Plot Learning Stats
|
||||
tb.add_scalar('Loss/TotalLoss', loss.data[0], current_step)
|
||||
tb.add_scalar('Loss/LinearLoss', linear_loss.data[0],
|
||||
current_step)
|
||||
tb.add_scalar('Loss/MelLoss', mel_loss.data[0], current_step)
|
||||
|
||||
tb.add_scalar('Params/LearningRate', optimizer.param_groups[0]['lr'],
|
||||
current_step)
|
||||
tb.add_scalar('Params/GradNorm', grad_norm, current_step)
|
||||
|
|
Loading…
Reference in New Issue