mirror of https://github.com/coqui-ai/TTS.git

commit 33937f54d0 "masked loss" (parent 0f3b2ddd7b)
@@ -98,9 +98,6 @@ class LJSpeechDataset(Dataset):
     mel = [self.ap.melspectrogram(w).astype('float32') for w in wav]
     mel_lengths = [m.shape[1] for m in mel]

-    # compute 'stop token' targets
-    stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths]
-
     # PAD sequences with largest length of the batch
     text = prepare_data(text).astype(np.int32)
     wav = prepare_data(wav)
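prepare_data itself is not part of this diff; a minimal sketch of a batch-padding helper like it, assuming it right-pads each 1-D sequence with zeros up to the longest item in the batch (names here are illustrative):

```python
import numpy as np

def prepare_data_sketch(inputs):
    # hypothetical stand-in for prepare_data: stack variable-length
    # 1-D arrays into one (batch, max_len) array, zero-padded on the right
    max_len = max(x.shape[0] for x in inputs)
    return np.stack([
        np.pad(x, (0, max_len - x.shape[0]), mode='constant')
        for x in inputs
    ])
```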
@@ -111,9 +108,6 @@ class LJSpeechDataset(Dataset):
     assert mel.shape[2] == linear.shape[2]
     timesteps = mel.shape[2]

-    # PAD stop targets
-    stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
-
     # PAD with zeros that can be divided by outputs per step
     if (timesteps + 1) % self.outputs_per_step != 0:
         pad_len = self.outputs_per_step - \
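This padding step keeps the spectrogram length compatible with the decoder's reduction factor (outputs_per_step). pad_per_step is not shown in the diff; a minimal sketch of the idea, assuming it zero-pads the time axis of a (batch, dim, time) array (function name and pad_len formula are illustrative):

```python
import numpy as np

def pad_per_step_sketch(inputs, pad_len):
    # zero-pad the time axis of a (batch, dim, time) spectrogram batch
    return np.pad(inputs, [(0, 0), (0, 0), (0, pad_len)], mode='constant')

# choose pad_len so that (timesteps + 1) becomes divisible by outputs_per_step
timesteps, outputs_per_step = 123, 5
if (timesteps + 1) % outputs_per_step != 0:
    pad_len = outputs_per_step - (timesteps + 1) % outputs_per_step
    # (123 + 1) % 5 == 4, so pad_len == 1 and (timesteps + pad_len + 1) == 125
    # is divisible by the reduction factor 5
```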
@@ -124,7 +118,16 @@ class LJSpeechDataset(Dataset):
         linear = pad_per_step(linear, pad_len)
         mel = pad_per_step(mel, pad_len)

-    # reshape mojo
+    # update mel lengths
+    mel_lengths = [l+pad_len for l in mel_lengths]
+
+    # compute 'stop token' targets
+    stop_targets = [np.array([0.]*mel_len) for mel_len in mel_lengths]
+
+    # PAD stop targets
+    stop_targets = prepare_stop_target(stop_targets, self.outputs_per_step)
+
+    # B x T x D
     linear = linear.transpose(0, 2, 1)
     mel = mel.transpose(0, 2, 1)

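The stop targets are now built after padding, so each per-utterance length already includes pad_len. prepare_stop_target is not shown in this diff; a minimal sketch of what it might do, assuming it right-pads each zero vector with ones up to a common length divisible by outputs_per_step (the +1 mirrors the (timesteps + 1) convention above and is an assumption):

```python
import numpy as np

def prepare_stop_target_sketch(stop_targets, out_steps):
    # hypothetical stand-in for prepare_stop_target: real frames stay 0,
    # padded frames become 1 ("stop here"), and the common length is
    # rounded up to a multiple of the reduction factor out_steps
    max_len = max(len(t) for t in stop_targets) + 1   # assumed +1: room for a stop frame
    remainder = max_len % out_steps
    pad_len = max_len + (out_steps - remainder) if remainder > 0 else max_len
    return np.stack([
        np.pad(t, (0, pad_len - len(t)), mode='constant', constant_values=1.0)
        for t in stop_targets
    ])
```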
@@ -133,8 +136,9 @@ class LJSpeechDataset(Dataset):
     text = torch.LongTensor(text)
     linear = torch.FloatTensor(linear)
     mel = torch.FloatTensor(mel)
+    mel_lengths = torch.LongTensor(mel_lengths)
     stop_targets = torch.FloatTensor(stop_targets)
-    return text, text_lenghts, linear, mel, stop_targets, item_idxs[0]
+    return text, text_lenghts, linear, mel, mel_lengths, stop_targets, item_idxs[0]

     raise TypeError(("batch must contain tensors, numbers, dicts or lists;\
         found {}"
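With this change the collate function returns a 7-tuple instead of a 6-tuple. A minimal sketch of how a consumer unpacks one batch (index layout taken from the updated tests and train.py below; the shape comments are assumptions):

```python
# illustrative unpacking of one batch produced by the collate function above
for data in data_loader:
    text_input = data[0]     # padded token ids
    text_lengths = data[1]   # per-item text lengths
    linear_input = data[2]   # (B, T, n_freq) linear spectrograms
    mel_input = data[3]      # (B, T, n_mels) mel spectrograms
    mel_lengths = data[4]    # (B,) mel lengths, new in this commit
    stop_targets = data[5]   # (B, T) 0/1 stop-token targets
    item_idx = data[6]       # id of the first item in the batch
    break
```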
@@ -37,16 +37,12 @@ class DecoderTests(unittest.TestCase):
     dummy_memory = T.autograd.Variable(T.rand(4, 120, 32))

     print(layer)
-    output, alignment, stop_output = layer(dummy_input, dummy_memory)
+    output, alignment = layer(dummy_input, dummy_memory)
     print(output.shape)
-    print(" > Stop ", stop_output.shape)

     assert output.shape[0] == 4
     assert output.shape[1] == 120 / 5
     assert output.shape[2] == 32 * 5
-    assert stop_output.shape[0] == 4
-    assert stop_output.shape[1] == 120 / 5
-    assert stop_output.shape[2] == 5


 class EncoderTests(unittest.TestCase):
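The expected output shapes follow from the decoder's reduction factor of 5 used in the test: 120 memory frames become 120/5 decoder steps, and each step emits 5 frames of 32 values at once. A small check of that arithmetic:

```python
batch, frames, feat_dim, r = 4, 120, 32, 5
decoder_steps = frames // r   # 24 decoder iterations over the memory
out_dim = feat_dim * r        # 160 values emitted per iteration
assert (batch, decoder_steps, out_dim) == (4, 24, 160)
```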
@@ -43,8 +43,9 @@ class TestDataset(unittest.TestCase):
     text_lengths = data[1]
     linear_input = data[2]
     mel_input = data[3]
-    stop_targets = data[4]
-    item_idx = data[5]
+    mel_lengths = data[4]
+    stop_target = data[5]
+    item_idx = data[6]

     neg_values = text_input[text_input < 0]
     check_count = len(neg_values)
@@ -82,8 +83,9 @@ class TestDataset(unittest.TestCase):
     text_lengths = data[1]
     linear_input = data[2]
     mel_input = data[3]
-    stop_target = data[4]
-    item_idx = data[5]
+    mel_lengths = data[4]
+    stop_target = data[5]
+    item_idx = data[6]

     # check the last time step to be zero padded
     assert mel_input[0, -1].sum() == 0
@@ -92,6 +94,10 @@ class TestDataset(unittest.TestCase):
     assert linear_input[0, -2].sum() != 0
     assert stop_target[0, -1] == 1
     assert stop_target.sum() == 1
+    assert len(mel_lengths.shape) == 1
+    print(mel_lengths)
+    print(mel_input)
+    assert mel_lengths[0] == mel_input[0].shape[0]



train.py (21 changes)
@@ -26,6 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
+from losses import


 use_cuda = torch.cuda.is_available()
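The new import line is truncated in this view, so the exact name pulled from losses is not visible; given the commit title, it is presumably a length-masked criterion. A minimal sketch of such a criterion, assuming an L1 loss averaged only over valid (unpadded) frames and the call signature criterion(prediction, target, lengths) used below; the class and helper names are illustrative, not necessarily the repo's:

```python
import torch
import torch.nn as nn


def sequence_mask(lengths, max_len):
    # (B, max_len) boolean mask, True where frame t is inside lengths[b]
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) < lengths.unsqueeze(1)


class L1LossMasked(nn.Module):
    # illustrative masked L1 criterion: criterion(pred, target, lengths)
    def forward(self, pred, target, lengths):
        # pred, target: (B, T, D); lengths: (B,) valid frame counts
        mask = sequence_mask(lengths, target.size(1)).unsqueeze(2).float()
        loss = torch.abs(pred - target) * mask
        # average over valid elements only, so padding does not dilute the loss
        return loss.sum() / (mask.sum() * target.size(2))
```

Normalizing by the mask sum rather than by the full padded size keeps heavily padded batches from under-weighting the real frames, which is the point of passing mel_lengths into the criterion.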
@@ -80,6 +81,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
     text_lengths = data[1]
     linear_input = data[2]
     mel_input = data[3]
+    mel_lengths = data[4]

     current_step = num_iter + args.restore_step + epoch * len(data_loader) + 1

@@ -93,6 +95,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
     # convert inputs to variables
     text_input_var = Variable(text_input)
     mel_spec_var = Variable(mel_input)
+    mel_lengths_var = Variable(mel_lengths)
     linear_spec_var = Variable(linear_input, volatile=True)

     # sort sequence by length for curriculum learning
@@ -108,6 +111,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
     if use_cuda:
         text_input_var = text_input_var.cuda()
         mel_spec_var = mel_spec_var.cuda()
+        mel_lengths_var = mel_lengths_var.cuda()
         linear_spec_var = linear_spec_var.cuda()

     # forward pass
@@ -115,10 +119,11 @@ def train(model, criterion, data_loader, optimizer, epoch):
         model.forward(text_input_var, mel_spec_var)

     # loss computation
-    mel_loss = criterion(mel_output, mel_spec_var)
-    linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+    mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
+    linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
         + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                          linear_spec_var[: ,: ,:n_priority_freq])
+                          linear_spec_var[: ,: ,:n_priority_freq],
+                          mel_lengths)
     loss = mel_loss + linear_loss

     # backpass and check the grad norm
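Each criterion call now takes mel_lengths as a third argument, which drives a per-frame mask like the one sketched above. A tiny worked example of what that mask looks like for a padded batch:

```python
import torch

# batch of 2 utterances padded to T=5 frames; only 3 frames of the first are real
lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths, 5)   # helper from the sketch above
print(mask.int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1]])
# 8 of the 10 frame slots contribute to the averaged loss; the 2 padded
# frames of the short utterance are ignored by the masked criterion
```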
@@ -215,26 +220,30 @@ def evaluate(model, criterion, data_loader, current_step):
     text_lengths = data[1]
     linear_input = data[2]
     mel_input = data[3]
+    mel_lengths = data[4]

     # convert inputs to variables
     text_input_var = Variable(text_input)
     mel_spec_var = Variable(mel_input)
+    mel_lengths_var = Variable(mel_lengths)
     linear_spec_var = Variable(linear_input, volatile=True)

     # dispatch data to GPU
     if use_cuda:
         text_input_var = text_input_var.cuda()
         mel_spec_var = mel_spec_var.cuda()
+        mel_lengths_var = mel_lengths_var.cuda()
         linear_spec_var = linear_spec_var.cuda()

     # forward pass
     mel_output, linear_output, alignments = model.forward(text_input_var, mel_spec_var)

     # loss computation
-    mel_loss = criterion(mel_output, mel_spec_var)
-    linear_loss = 0.5 * criterion(linear_output, linear_spec_var) \
+    mel_loss = criterion(mel_output, mel_spec_var, mel_lengths)
+    linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths) \
         + 0.5 * criterion(linear_output[:, :, :n_priority_freq],
-                          linear_spec_var[: ,: ,:n_priority_freq])
+                          linear_spec_var[: ,: ,:n_priority_freq],
+                          mel_lengths)
     loss = mel_loss + linear_loss

     step_time = time.time() - start_time