remove 'Variable' from train.py

pull/10/head
Eren Golge 2018-05-10 16:05:03 -07:00
parent b087c0b5ec
commit a2eeec31be
1 changed files with 20 additions and 33 deletions

View File

@ -13,7 +13,6 @@ import numpy as np
import torch.nn as nn
from torch import optim
from torch import onnx
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tensorboardX import SummaryWriter
@ -93,29 +92,23 @@ def train(model, criterion, data_loader, optimizer, epoch):
optimizer.zero_grad()
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
mel_lengths_var = Variable(mel_lengths)
linear_spec_var = Variable(linear_input, volatile=True)
# dispatch data to GPU
if use_cuda:
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
mel_lengths_var = mel_lengths_var.cuda()
linear_spec_var = linear_spec_var.cuda()
text_input = text_input.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var)
model.forward(text_input, mel_input)
# loss computation
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[:, :, :n_priority_freq],
mel_lengths_var)
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
# backpass and check the grad norm
@ -157,7 +150,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
# Diagnostic visualizations
const_spec = linear_output[0].data.cpu().numpy()
gt_spec = linear_spec_var[0].data.cpu().numpy()
gt_spec = linear_input[0].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)
gt_spec = plot_spectrogram(gt_spec, data_loader.dataset.ap)
@ -215,29 +208,23 @@ def evaluate(model, criterion, data_loader, current_step):
mel_input = data[3]
mel_lengths = data[4]
# convert inputs to variables
text_input_var = Variable(text_input)
mel_spec_var = Variable(mel_input)
mel_lengths_var = Variable(mel_lengths)
linear_spec_var = Variable(linear_input, volatile=True)
# dispatch data to GPU
if use_cuda:
text_input_var = text_input_var.cuda()
mel_spec_var = mel_spec_var.cuda()
mel_lengths_var = mel_lengths_var.cuda()
linear_spec_var = linear_spec_var.cuda()
text_input = text_input.cuda()
mel_input = mel_input.cuda()
mel_lengths = mel_lengths.cuda()
linear_input = linear_input.cuda()
# forward pass
mel_output, linear_output, alignments =\
model.forward(text_input_var, mel_spec_var)
model.forward(text_input, mel_input)
# loss computation
mel_loss = criterion(mel_output, mel_spec_var, mel_lengths_var)
linear_loss = 0.5 * criterion(linear_output, linear_spec_var, mel_lengths_var) \
mel_loss = criterion(mel_output, mel_input, mel_lengths)
linear_loss = 0.5 * criterion(linear_output, linear_input, mel_lengths) \
+ 0.5 * criterion(linear_output[:, :, :n_priority_freq],
linear_spec_var[:, :, :n_priority_freq],
mel_lengths_var)
linear_input[:, :, :n_priority_freq],
mel_lengths)
loss = mel_loss + linear_loss
step_time = time.time() - start_time
@ -255,7 +242,7 @@ def evaluate(model, criterion, data_loader, current_step):
# Diagnostic visualizations
idx = np.random.randint(mel_input.shape[0])
const_spec = linear_output[idx].data.cpu().numpy()
gt_spec = linear_spec_var[idx].data.cpu().numpy()
gt_spec = linear_input[idx].data.cpu().numpy()
align_img = alignments[idx].data.cpu().numpy()
const_spec = plot_spectrogram(const_spec, data_loader.dataset.ap)