refactor(wavernn): remove duplicate Stretch2d

I checked that the implementations are the same
pull/4115/head^2
Enno Hermann 2024-11-22 01:16:42 +01:00
parent e63962c226
commit 2e5f68df6a
1 changed files with 1 additions and 13 deletions

View File

@ -17,6 +17,7 @@ from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import mulaw_decode
from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset
from TTS.vocoder.layers.losses import WaveRNNLoss
from TTS.vocoder.layers.upsample import Stretch2d
from TTS.vocoder.models.base_vocoder import BaseVocoder
from TTS.vocoder.utils.distribution import sample_from_discretized_mix_logistic, sample_from_gaussian
@ -66,19 +67,6 @@ class MelResNet(nn.Module):
return x
class Stretch2d(nn.Module):
def __init__(self, x_scale, y_scale):
super().__init__()
self.x_scale = x_scale
self.y_scale = y_scale
def forward(self, x):
b, c, h, w = x.size()
x = x.unsqueeze(-1).unsqueeze(3)
x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale)
return x.view(b, c, h * self.y_scale, w * self.x_scale)
class UpsampleNetwork(nn.Module):
def __init__(
self,