mirror of https://github.com/coqui-ai/TTS.git
refactor(wavernn): remove duplicate Stretch2d
I checked that the implementations are the samepull/4115/head^2
parent
e63962c226
commit
2e5f68df6a
|
@ -17,6 +17,7 @@ from TTS.utils.audio import AudioProcessor
|
|||
from TTS.utils.audio.numpy_transforms import mulaw_decode
|
||||
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
|
||||
from TTS.vocoder.layers.losses import WaveRNNLoss
|
||||
from TTS.vocoder.layers.upsample import Stretch2d
|
||||
from TTS.vocoder.models.base_vocoder import BaseVocoder
|
||||
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
|
||||
|
||||
|
@ -66,19 +67,6 @@ class MelResNet(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class Stretch2d(nn.Module):
|
||||
def __init__(self, x_scale, y_scale):
|
||||
super().__init__()
|
||||
self.x_scale = x_scale
|
||||
self.y_scale = y_scale
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.size()
|
||||
x = x.unsqueeze(-1).unsqueeze(3)
|
||||
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
|
||||
return x.view(b, c, h * self.y_scale, w * self.x_scale)
|
||||
|
||||
|
||||
class UpsampleNetwork(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
Loading…
Reference in New Issue