change wavernn generate to inference

pull/10/head
erogol 2020-11-12 12:52:52 +01:00
parent 9b0f441945
commit 25551c4634
1 changed files with 4 additions and 3 deletions

View File

@ -260,17 +260,18 @@ class WaveRNN(nn.Module):
x = F.relu(self.fc2(x))
return self.fc3(x)
def generate(self, mels, batched, target, overlap, use_cuda=False):
def inference(self, mels, batched, target, overlap):
self.eval()
device = 'cuda' if use_cuda else 'cpu'
device = mels.device
output = []
start = time.time()
rnn1 = self.get_gru_cell(self.rnn1)
rnn2 = self.get_gru_cell(self.rnn2)
with torch.no_grad():
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
if isinstance(mels, np.ndarray):
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
wave_len = (mels.size(-1) - 1) * self.hop_length
mels = self.pad_tensor(mels.transpose(