mirror of https://github.com/coqui-ai/TTS.git
change wavernn generate to inference
parent
9b0f441945
commit
25551c4634
|
@ -260,17 +260,18 @@ class WaveRNN(nn.Module):
|
||||||
x = F.relu(self.fc2(x))
|
x = F.relu(self.fc2(x))
|
||||||
return self.fc3(x)
|
return self.fc3(x)
|
||||||
|
|
||||||
def generate(self, mels, batched, target, overlap, use_cuda=False):
|
def inference(self, mels, batched, target, overlap):
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
device = 'cuda' if use_cuda else 'cpu'
|
device = mels.device
|
||||||
output = []
|
output = []
|
||||||
start = time.time()
|
start = time.time()
|
||||||
rnn1 = self.get_gru_cell(self.rnn1)
|
rnn1 = self.get_gru_cell(self.rnn1)
|
||||||
rnn2 = self.get_gru_cell(self.rnn2)
|
rnn2 = self.get_gru_cell(self.rnn2)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
|
if isinstance(mels, np.ndarray):
|
||||||
|
mels = torch.FloatTensor(mels).unsqueeze(0).to(device)
|
||||||
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
|
#mels = torch.FloatTensor(mels).cuda().unsqueeze(0)
|
||||||
wave_len = (mels.size(-1) - 1) * self.hop_length
|
wave_len = (mels.size(-1) - 1) * self.hop_length
|
||||||
mels = self.pad_tensor(mels.transpose(
|
mels = self.pad_tensor(mels.transpose(
|
||||||
|
|
Loading…
Reference in New Issue