mirror of https://github.com/coqui-ai/TTS.git
change wavernn generate to inference
parent
9b0f441945
commit
25551c4634
|
@ -260,17 +260,18 @@ class WaveRNN(nn.Module):
|
|||
x = F.relu(self.fc2(x))
|
||||
return self.fc3(x)
|
||||
|
||||
def generate(self, mels, batched, target, overlap, use_cuda=False):
|
||||
def inference(self, mels, batched, target, overlap):
|
||||
|
||||
self.eval()
|
||||
device = 'cuda' if use_cuda else 'cpu'
|
||||
device = mels.device
|
||||
output = []
|
||||
start = time.time()
|
||||
rnn1 = self.get_gru_cell(self.rnn1)
|
||||
rnn2 = self.get_gru_cell(self.rnn2)
|
||||
|
||||
with torch.no_grad():
|
||||
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
|
||||
if isinstance(mels, np.ndarray):
|
||||
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
|
||||
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
|
||||
wave_len = (mels.size(-1) - 1) * self.hop_length
|
||||
mels = self.pad_tensor(mels.transpose(
|
||||
|
|
Loading…
Reference in New Issue