mirror of https://github.com/coqui-ai/TTS.git
remove 'Variable' from train.py
parent
b087c0b5ec
commit
a2eeec31be
53
train.py
53
train.py
|
@@ -13,7 +13,6 @@ import numpy as np
|
|||
import torch.nn as nn
|
||||
from torch import optim
|
||||
from torch import onnx
|
||||
from torch.autograd import Variable
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
||||
from tensorboardX import SummaryWriter
|
||||
|
@@ -93,29 +92,23 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
|
||||
optimizer.zero_grad()
|
||||
|
||||
# convert inputs to variables
|
||||
text_input_var = Variable(text_input)
|
||||
mel_spec_var = Variable(mel_input)
|
||||
mel_lengths_var = Variable(mel_lengths)
|
||||
linear_spec_var = Variable(linear_input, volatile=True)
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input_var = text_input_var.cuda()
|
||||
mel_spec_var = mel_spec_var.cuda()
|
||||
mel_lengths_var = mel_lengths_var.cuda()
|
||||
linear_spec_var = linear_spec_var.cuda()
|
||||
text_input = text_input.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input_var, mel_spec_var)
|
||||
model.forward(text_input, mel_input)
|
||||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_spec_var[:, :, :n_priority_freq],
|
||||
mel_lengths_var)
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss
|
||||
|
||||
# backpass and check the grad norm
|
||||
|
@@ -157,7 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
|
|||
|
||||
# Diagnostic visualizations
|
||||
const_spec = linear_output[0].data.cpu().numpy()
|
||||
gt_spec = linear_spec_var[0].data.cpu().numpy()
|
||||
gt_spec = linear_input[0].data.cpu().numpy()
|
||||
|
||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
|
||||
|
@@ -215,29 +208,23 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
|
||||
# convert inputs to variables
|
||||
text_input_var = Variable(text_input)
|
||||
mel_spec_var = Variable(mel_input)
|
||||
mel_lengths_var = Variable(mel_lengths)
|
||||
linear_spec_var = Variable(linear_input, volatile=True)
|
||||
|
||||
# dispatch data to GPU
|
||||
if use_cuda:
|
||||
text_input_var = text_input_var.cuda()
|
||||
mel_spec_var = mel_spec_var.cuda()
|
||||
mel_lengths_var = mel_lengths_var.cuda()
|
||||
linear_spec_var = linear_spec_var.cuda()
|
||||
text_input = text_input.cuda()
|
||||
mel_input = mel_input.cuda()
|
||||
mel_lengths = mel_lengths.cuda()
|
||||
linear_input = linear_input.cuda()
|
||||
|
||||
# forward pass
|
||||
mel_output, linear_output, alignments =\
|
||||
model.forward(text_input_var, mel_spec_var)
|
||||
model.forward(text_input, mel_input)
|
||||
|
||||
# loss computation
|
||||
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
|
||||
mel_loss = criterion(mel_output, mel_input, mel_lengths)
|
||||
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
|
||||
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
|
||||
linear_spec_var[:, :, :n_priority_freq],
|
||||
mel_lengths_var)
|
||||
linear_input[:, :, :n_priority_freq],
|
||||
mel_lengths)
|
||||
loss = mel_loss + linear_loss
|
||||
|
||||
step_time = time.time() - start_time
|
||||
|
@@ -255,7 +242,7 @@ def evaluate(model, criterion, data_loader, current_step):
|
|||
# Diagnostic visualizations
|
||||
idx = np.random.randint(mel_input.shape[0])
|
||||
const_spec = linear_output[idx].data.cpu().numpy()
|
||||
gt_spec = linear_spec_var[idx].data.cpu().numpy()
|
||||
gt_spec = linear_input[idx].data.cpu().numpy()
|
||||
align_img = alignments[idx].data.cpu().numpy()
|
||||
|
||||
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
|
||||
|
|
Loading…
Reference in New Issue