diff --git a/TTS/vocoder/models/hifigan_generator.py b/TTS/vocoder/models/hifigan_generator.py
index 1da726c6..de8a3d87 100644
--- a/TTS/vocoder/models/hifigan_generator.py
+++ b/TTS/vocoder/models/hifigan_generator.py
@@ -12,6 +12,19 @@ def get_padding(k, d):
 
 
 class ResBlock1(torch.nn.Module):
+    """Residual Block Type 1. It has 3 convolutional layers in each convolutional block.
+
+    Network:
+
+        x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
+        |--------------------------------------------------------------------------------------------------|
+
+
+    Args:
+        channels (int): number of hidden channels for the convolutional layers.
+        kernel_size (int): size of the convolution filter in each layer.
+        dilation (list): list of dilation values for each conv layer in a block.
+    """
     def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
         super().__init__()
         self.convs1 = nn.ModuleList([
@@ -63,13 +76,21 @@ class ResBlock1(torch.nn.Module):
         ])
 
     def forward(self, x):
+        """
+        Args:
+            x (Tensor): input tensor.
+        Returns:
+            Tensor: output tensor.
+        Shapes:
+            x: [B, C, T]
+        """
         for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            xt = c2(xt)
-            x = xt + x
-        return x
+            o = F.leaky_relu(x, LRELU_SLOPE)
+            o = c1(o)
+            o = F.leaky_relu(o, LRELU_SLOPE)
+            o = c2(o)
+            o = o + x
+        return o
 
     def remove_weight_norm(self):
         for l in self.convs1:
@@ -79,6 +100,19 @@ class ResBlock1(torch.nn.Module):
 
 
 class ResBlock2(torch.nn.Module):
+    """Residual Block Type 2. It has 1 convolutional layer in each convolutional block.
+
+    Network:
+
+        x -> lrelu -> conv1 -> z -> lrelu -> conv2 -> o -> + -> o
+        |---------------------------------------------------|
+
+
+    Args:
+        channels (int): number of hidden channels for the convolutional layers.
+        kernel_size (int): size of the convolution filter in each layer.
+        dilation (list): list of dilation values for each conv layer in a block.
+    """
     def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
         super().__init__()
         self.convs = nn.ModuleList([
@@ -100,10 +134,10 @@ class ResBlock2(torch.nn.Module):
 
     def forward(self, x):
         for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            xt = c(xt)
-            x = xt + x
-        return x
+            o = F.leaky_relu(x, LRELU_SLOPE)
+            o = c(o)
+            o = o + x
+        return o
 
     def remove_weight_norm(self):
         for l in self.convs:
@@ -111,17 +145,38 @@ class ResBlock2(torch.nn.Module):
 
 
 class HifiganGenerator(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, resblock_type, resblock_dilation_sizes,
-                 resblock_kernel_sizes, upsample_kernel_sizes,
-                 upsample_initial_channel, upsample_factors):
+    def __init__(self, in_channels, out_channels, resblock_type,
+                 resblock_dilation_sizes, resblock_kernel_sizes,
+                 upsample_kernel_sizes, upsample_initial_channel,
+                 upsample_factors, inference_padding=5):
+        r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
+
+        Network:
+            x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
+                                                 ..          -> zI ---|
+                                              resblockN_kNx1 -> zN ---'
+
+        Args:
+            in_channels (int): number of input tensor channels.
+            out_channels (int): number of output tensor channels.
+            resblock_type (str): type of the `ResBlock`. '1' or '2'.
+            resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
+            resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
+            upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
+            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
+                for each consecutive upsampling layer.
+            upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
+            inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
+        """
         super().__init__()
-        self.inference_padding = 5
+        self.inference_padding = inference_padding
         self.num_kernels = len(resblock_kernel_sizes)
         self.num_upsamples = len(upsample_factors)
+        # initial upsampling layers
         self.conv_pre = weight_norm(
             Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
         resblock = ResBlock1 if resblock_type == '1' else ResBlock2
-
+        # upsampling layers
         self.ups = nn.ModuleList()
         for i, (u, k) in enumerate(
                 zip(upsample_factors, upsample_kernel_sizes)):
@@ -132,32 +187,32 @@ class HifiganGenerator(torch.nn.Module):
                                     k,
                                     u,
                                     padding=(k - u) // 2)))
-
+        # MRF blocks
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = upsample_initial_channel // (2**(i + 1))
             for j, (k, d) in enumerate(
                     zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(resblock(ch, k, d))
-
+        # post convolution layer
         self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
 
     def forward(self, x):
-        x = self.conv_pre(x)
+        o = self.conv_pre(x)
         for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
+            o = F.leaky_relu(o, LRELU_SLOPE)
+            o = self.ups[i](o)
+            z_sum = None
             for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
+                if z_sum is None:
+                    z_sum = self.resblocks[i * self.num_kernels + j](o)
                 else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-        return x
+                    z_sum += self.resblocks[i * self.num_kernels + j](o)
+            o = z_sum / self.num_kernels
+        o = F.leaky_relu(o)
+        o = self.conv_post(o)
+        o = torch.tanh(o)
+        return o
 
     @torch.no_grad()
     def inference(self, c):
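For reviewers, a minimal usage sketch of the patched generator (not part of the diff). The hyperparameters below are assumptions taken from the HiFiGAN paper's V1 configuration rather than from this repository's config files, and the mel input is random data used only to check shapes.

```python
import torch

from TTS.vocoder.models.hifigan_generator import HifiganGenerator

# Hypothetical config: HiFiGAN V1 values from the paper, used here only to
# illustrate the new constructor signature (incl. `inference_padding`).
model = HifiganGenerator(
    in_channels=80,                  # e.g. an 80-band mel spectrogram
    out_channels=1,                  # mono waveform
    resblock_type='1',
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    resblock_kernel_sizes=[3, 7, 11],
    upsample_kernel_sizes=[16, 16, 4, 4],
    upsample_initial_channel=512,
    upsample_factors=[8, 8, 2, 2],   # total upsampling: 8 * 8 * 2 * 2 = 256x
    inference_padding=5,
)

mel = torch.randn(1, 80, 100)        # [B, C, T]
wav = model(mel)                     # -> [1, 1, 100 * 256]
print(wav.shape)
```

Promoting `inference_padding` from a hardcoded constant to a constructor argument keeps the previous behavior as the default (5) while letting model configs tune it.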