Fix UnivNet inference code

pull/602/head
Eren Gölge 2021-07-02 10:48:34 +02:00
parent 168f97cbe9
commit 9e7824fe35
3 changed files with 11 additions and 18 deletions

View File

@ -15,7 +15,7 @@ def get_padding(k, d):
class ResBlock1(torch.nn.Module):
"""Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
Network:
Network::
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|--------------------------------------------------------------------------------------------------|
@ -105,7 +105,7 @@ class ResBlock1(torch.nn.Module):
class ResBlock2(torch.nn.Module):
"""Residual Block Type 2. It has 3 convolutional layers in each convolutional block.
Network:
Network::
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|---------------------------------------------------|

View File

@ -122,24 +122,16 @@ class UnivnetGenerator(torch.nn.Module):
"""Return receptive field size."""
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
def inference(self, c=None, x=None):
@torch.no_grad()
def inference(self, c):
"""Perform inference.
Args:
c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T' ,C).
x (Union[Tensor, ndarray]): Input noise signal (T, 1).
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
Returns:
Tensor: Output tensor as returned by ``forward(c)`` (no squeeze or transpose is applied).
"""
if x is not None:
if not isinstance(x, torch.Tensor):
x = torch.tensor(x, dtype=torch.float).to(next(self.parameters()).device)
x = x.transpose(1, 0).unsqueeze(0)
else:
assert c is not None
x = torch.randn(1, 1, len(c) * self.upsample_factor).to(next(self.parameters()).device)
if c is not None:
if not isinstance(c, torch.Tensor):
c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device)
c = c.transpose(1, 0).unsqueeze(0)
c = torch.nn.ReplicationPad1d(self.aux_context_window)(c)
return self.forward(c).squeeze(0).transpose(1, 0)
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
x = x.to(self.first_conv.bias.device)
c = c.to(next(self.parameters()))
return self.forward(c)

View File

@ -1,5 +1,6 @@
import os
from TTS.config.shared_configs import BaseAudioConfig
from TTS.trainer import Trainer, TrainingArgs, init_training
from TTS.vocoder.configs import UnivnetConfig