From 903a77c1979c105a7fa3bb7d6b64668171781cf0 Mon Sep 17 00:00:00 2001
From: p0p4k
Date: Mon, 1 Aug 2022 19:20:37 +0900
Subject: [PATCH] Update wavenet.py (#1796)

* Update wavenet.py

The current version does not use the "in_channels" argument. In GlowTTS we
use normalizing flows, so "input dim" == "output dim" (in both channels and
length); the existing code therefore feeds a hidden_channels-sized tensor to
the first layer and also outputs a hidden_channels-sized tensor. However,
since this is a generic implementation, I believe it is better to update it
for more general use.

* "in_channels -> hidden_channels"
---
 TTS/tts/layers/generic/wavenet.py | 11 ++++++++---
 TTS/tts/layers/glow_tts/glow.py   |  2 +-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/TTS/tts/layers/generic/wavenet.py b/TTS/tts/layers/generic/wavenet.py
index aeb45c7b..613ad19d 100644
--- a/TTS/tts/layers/generic/wavenet.py
+++ b/TTS/tts/layers/generic/wavenet.py
@@ -67,9 +67,14 @@ class WN(torch.nn.Module):
         for i in range(num_layers):
             dilation = dilation_rate**i
             padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(
-                hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
-            )
+            if i == 0:
+                in_layer = torch.nn.Conv1d(
+                    in_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
+            else:
+                in_layer = torch.nn.Conv1d(
+                    hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilation, padding=padding
+                )
             in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
 
diff --git a/TTS/tts/layers/glow_tts/glow.py b/TTS/tts/layers/glow_tts/glow.py
index ff1b99e8..3b745018 100644
--- a/TTS/tts/layers/glow_tts/glow.py
+++ b/TTS/tts/layers/glow_tts/glow.py
@@ -197,7 +197,7 @@ class CouplingBlock(nn.Module):
         end.bias.data.zero_()
         self.end = end
         # coupling layers
-        self.wn = WN(in_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
+        self.wn = WN(hidden_channels, hidden_channels, kernel_size, dilation_rate, num_layers, c_in_channels, dropout_p)
 
     def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):  # pylint: disable=unused-argument
         """
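For context, a standalone sketch of the construction logic the wavenet.py hunk
introduces. It deliberately mirrors the hunk instead of importing TTS, the
channel sizes are made-up illustrative values, and the conditional input size
shown is just an equivalent, more compact form of the if/else above:

    import torch

    in_channels, hidden_channels = 80, 192            # illustrative, not from the patch
    kernel_size, dilation_rate, num_layers = 5, 1, 4  # illustrative, not from the patch

    in_layers = torch.nn.ModuleList()
    for i in range(num_layers):
        dilation = dilation_rate**i
        padding = int((kernel_size * dilation - dilation) / 2)
        # Only the first layer reads the WN input, so only it uses in_channels;
        # every later layer consumes the hidden_channels-sized residual stream.
        in_layer = torch.nn.Conv1d(
            in_channels if i == 0 else hidden_channels,
            2 * hidden_channels,
            kernel_size,
            dilation=dilation,
            padding=padding,
        )
        in_layers.append(in_layer)

    x = torch.randn(1, in_channels, 100)              # [batch, channels, time]
    print(in_layers[0](x).shape)                      # torch.Size([1, 384, 100])
    print(in_layers[1](torch.randn(1, hidden_channels, 100)).shape)  # same shape

The glow.py hunk is consistent with this: in CouplingBlock, the input half is
first projected by the "start" 1x1 convolution from in_channels // 2 up to
hidden_channels before it reaches WN, so the tensor WN actually receives is
hidden_channels-sized, and "WN(hidden_channels, hidden_channels, ...)" is the
correct call under the patched constructor.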