Remove variables

pull/10/head
Eren Golge 2018-04-25 08:00:19 -07:00
parent 28baf66ae8
commit 52b4bc6bed
5 changed files with 1 additions and 9 deletions

View File

@ -74,7 +74,7 @@ class LJSpeechDataset(Dataset):
def get_dummy_data(self):
r"""Get a dummy input for testing"""
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
return torch.ones(16, 143).type(torch.LongTensor)
def collate_fn(self, batch):
r"""

View File

@ -72,10 +72,6 @@ class TWEBDataset(Dataset):
sample = {'text': text, 'wav': wav, 'item_idx': self.frames[idx][0]}
return sample
def get_dummy_data(self):
r"""Get a dummy input for testing"""
return torch.autograd.Variable(torch.ones(16, 143)).type(torch.LongTensor)
def collate_fn(self, batch):
r"""
Perform preprocessing and create a final data batch:

View File

@ -1,5 +1,4 @@
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F

View File

@ -1,6 +1,5 @@
# coding: utf-8
import torch
from torch.autograd import Variable
from torch import nn

View File

@ -1,6 +1,5 @@
import torch
from torch.nn import functional
from torch.autograd import Variable
from torch import nn
@ -11,7 +10,6 @@ def _sequence_mask(sequence_length, max_len=None):
batch_size = sequence_length.size(0)
seq_range = torch.arange(0, max_len).long()
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
seq_range_expand = Variable(seq_range_expand)
if sequence_length.is_cuda:
seq_range_expand = seq_range_expand.cuda()
seq_length_expand = (sequence_length.unsqueeze(1)