mirror of https://github.com/coqui-ai/TTS.git
Fix UnivNet inference code
parent
168f97cbe9
commit
9e7824fe35
|
@ -15,7 +15,7 @@ def get_padding(k, d):
|
|||
class ResBlock1(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutiona block.
|
||||
|
||||
Network:
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||
|--------------------------------------------------------------------------------------------------|
|
||||
|
@ -105,7 +105,7 @@ class ResBlock1(torch.nn.Module):
|
|||
class ResBlock2(torch.nn.Module):
|
||||
"""Residual Block Type 1. It has 3 convolutional layers in each convolutiona block.
|
||||
|
||||
Network:
|
||||
Network::
|
||||
|
||||
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||
|---------------------------------------------------|
|
||||
|
|
|
@ -122,24 +122,16 @@ class UnivnetGenerator(torch.nn.Module):
|
|||
"""Return receptive field size."""
|
||||
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
||||
|
||||
def inference(self, c=None, x=None):
|
||||
@torch.no_grad()
|
||||
def inference(self, c):
|
||||
"""Perform inference.
|
||||
Args:
|
||||
c (Union[Tensor, ndarray]): Local conditioning auxiliary features (T' ,C).
|
||||
x (Union[Tensor, ndarray]): Input noise signal (T, 1).
|
||||
c (Tensor): Local conditioning auxiliary features :math:`(B, C, T)`.
|
||||
Returns:
|
||||
Tensor: Output tensor (T, out_channels)
|
||||
"""
|
||||
if x is not None:
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = torch.tensor(x, dtype=torch.float).to(next(self.parameters()).device)
|
||||
x = x.transpose(1, 0).unsqueeze(0)
|
||||
else:
|
||||
assert c is not None
|
||||
x = torch.randn(1, 1, len(c) * self.upsample_factor).to(next(self.parameters()).device)
|
||||
if c is not None:
|
||||
if not isinstance(c, torch.Tensor):
|
||||
c = torch.tensor(c, dtype=torch.float).to(next(self.parameters()).device)
|
||||
c = c.transpose(1, 0).unsqueeze(0)
|
||||
c = torch.nn.ReplicationPad1d(self.aux_context_window)(c)
|
||||
return self.forward(c).squeeze(0).transpose(1, 0)
|
||||
x = torch.randn([c.shape[0], self.in_channels, c.shape[2]])
|
||||
x = x.to(self.first_conv.bias.device)
|
||||
|
||||
c = c.to(next(self.parameters()))
|
||||
return self.forward(c)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import os
|
||||
|
||||
from TTS.config.shared_configs import BaseAudioConfig
|
||||
from TTS.trainer import Trainer, TrainingArgs, init_training
|
||||
from TTS.vocoder.configs import UnivnetConfig
|
||||
|
||||
|
|
Loading…
Reference in New Issue