mirror of https://github.com/coqui-ai/TTS.git
Linter fixes and docstrings for HiFiGAN
parent
bd7a1c177b
commit
d95b1458e8
|
@ -12,6 +12,19 @@ def get_padding(k, d):
|
||||||
|
|
||||||
|
|
||||||
class ResBlock1(torch.nn.Module):
|
class ResBlock1(torch.nn.Module):
|
||||||
|
"""Residual Block Type 1. It has 3 convolutional layers in each convolutiona block.
|
||||||
|
|
||||||
|
Network:
|
||||||
|
|
||||||
|
x -> lrelu -> conv1_1 -> conv1_2 -> conv1_3 -> z -> lrelu -> conv2_1 -> conv2_2 -> conv2_3 -> o -> + -> o
|
||||||
|
|--------------------------------------------------------------------------------------------------|
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): number of hidden channels for the convolutional layers.
|
||||||
|
kernel_size (int): size of the convolution filter in each layer.
|
||||||
|
dilations (list): list of dilation value for each conv layer in a block.
|
||||||
|
"""
|
||||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.convs1 = nn.ModuleList([
|
self.convs1 = nn.ModuleList([
|
||||||
|
@ -63,13 +76,21 @@ class ResBlock1(torch.nn.Module):
|
||||||
])
|
])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x (Tensor): input tensor.
|
||||||
|
Returns:
|
||||||
|
Tensor: output tensor.
|
||||||
|
Shapes:
|
||||||
|
x: [B, C, T]
|
||||||
|
"""
|
||||||
for c1, c2 in zip(self.convs1, self.convs2):
|
for c1, c2 in zip(self.convs1, self.convs2):
|
||||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
o = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
xt = c1(xt)
|
o = c1(o)
|
||||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||||
xt = c2(xt)
|
o = c2(o)
|
||||||
x = xt + x
|
o = o + x
|
||||||
return x
|
return o
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs1:
|
for l in self.convs1:
|
||||||
|
@ -79,6 +100,19 @@ class ResBlock1(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class ResBlock2(torch.nn.Module):
|
class ResBlock2(torch.nn.Module):
|
||||||
|
"""Residual Block Type 1. It has 3 convolutional layers in each convolutiona block.
|
||||||
|
|
||||||
|
Network:
|
||||||
|
|
||||||
|
x -> lrelu -> conv1-> -> z -> lrelu -> conv2-> o -> + -> o
|
||||||
|
|---------------------------------------------------|
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
channels (int): number of hidden channels for the convolutional layers.
|
||||||
|
kernel_size (int): size of the convolution filter in each layer.
|
||||||
|
dilations (list): list of dilation value for each conv layer in a block.
|
||||||
|
"""
|
||||||
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.convs = nn.ModuleList([
|
self.convs = nn.ModuleList([
|
||||||
|
@ -100,10 +134,10 @@ class ResBlock2(torch.nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
for c in self.convs:
|
for c in self.convs:
|
||||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
o = F.leaky_relu(x, LRELU_SLOPE)
|
||||||
xt = c(xt)
|
o = c(o)
|
||||||
x = xt + x
|
o = o + x
|
||||||
return x
|
return o
|
||||||
|
|
||||||
def remove_weight_norm(self):
|
def remove_weight_norm(self):
|
||||||
for l in self.convs:
|
for l in self.convs:
|
||||||
|
@ -111,17 +145,38 @@ class ResBlock2(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class HifiganGenerator(torch.nn.Module):
|
class HifiganGenerator(torch.nn.Module):
|
||||||
def __init__(self, in_channels, out_channels, resblock_type, resblock_dilation_sizes,
|
def __init__(self, in_channels, out_channels, resblock_type,
|
||||||
resblock_kernel_sizes, upsample_kernel_sizes,
|
resblock_dilation_sizes, resblock_kernel_sizes,
|
||||||
upsample_initial_channel, upsample_factors):
|
upsample_kernel_sizes, upsample_initial_channel,
|
||||||
|
upsample_factors, inference_padding=5):
|
||||||
|
r"""HiFiGAN Generator with Multi-Receptive Field Fusion (MRF)
|
||||||
|
|
||||||
|
Network:
|
||||||
|
x -> lrelu -> upsampling_layer -> resblock1_k1x1 -> z1 -> + -> z_sum / #resblocks -> lrelu -> conv_post_7x1 -> tanh -> o
|
||||||
|
.. -> zI ---|
|
||||||
|
resblockN_kNx1 -> zN ---'
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (int): number of input tensor channels.
|
||||||
|
out_channels (int): number of output tensor channels.
|
||||||
|
resblock_type (str): type of the `ResBlock`. '1' or '2'.
|
||||||
|
resblock_dilation_sizes (List[List[int]]): list of dilation values in each layer of a `ResBlock`.
|
||||||
|
resblock_kernel_sizes (List[int]): list of kernel sizes for each `ResBlock`.
|
||||||
|
upsample_kernel_sizes (List[int]): list of kernel sizes for each transposed convolution.
|
||||||
|
upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
|
||||||
|
for each consecutive upsampling layer.
|
||||||
|
upsample_factors (List[int]): upsampling factors (stride) for each upsampling layer.
|
||||||
|
inference_padding (int): constant padding applied to the input at inference time. Defaults to 5.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.inference_padding = 5
|
self.inference_padding = inference_padding
|
||||||
self.num_kernels = len(resblock_kernel_sizes)
|
self.num_kernels = len(resblock_kernel_sizes)
|
||||||
self.num_upsamples = len(upsample_factors)
|
self.num_upsamples = len(upsample_factors)
|
||||||
|
# initial upsampling layers
|
||||||
self.conv_pre = weight_norm(
|
self.conv_pre = weight_norm(
|
||||||
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3))
|
||||||
resblock = ResBlock1 if resblock_type == '1' else ResBlock2
|
resblock = ResBlock1 if resblock_type == '1' else ResBlock2
|
||||||
|
# upsampling layers
|
||||||
self.ups = nn.ModuleList()
|
self.ups = nn.ModuleList()
|
||||||
for i, (u, k) in enumerate(zip(upsample_factors,
|
for i, (u, k) in enumerate(zip(upsample_factors,
|
||||||
upsample_kernel_sizes)):
|
upsample_kernel_sizes)):
|
||||||
|
@ -132,32 +187,32 @@ class HifiganGenerator(torch.nn.Module):
|
||||||
k,
|
k,
|
||||||
u,
|
u,
|
||||||
padding=(k - u) // 2)))
|
padding=(k - u) // 2)))
|
||||||
|
# MRF blocks
|
||||||
self.resblocks = nn.ModuleList()
|
self.resblocks = nn.ModuleList()
|
||||||
for i in range(len(self.ups)):
|
for i in range(len(self.ups)):
|
||||||
ch = upsample_initial_channel // (2**(i + 1))
|
ch = upsample_initial_channel // (2**(i + 1))
|
||||||
for j, (k, d) in enumerate(
|
for j, (k, d) in enumerate(
|
||||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||||
self.resblocks.append(resblock(ch, k, d))
|
self.resblocks.append(resblock(ch, k, d))
|
||||||
|
# post convolution layer
|
||||||
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
|
self.conv_post = weight_norm(Conv1d(ch, out_channels, 7, 1, padding=3))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.conv_pre(x)
|
o = self.conv_pre(x)
|
||||||
for i in range(self.num_upsamples):
|
for i in range(self.num_upsamples):
|
||||||
x = F.leaky_relu(x, LRELU_SLOPE)
|
o = F.leaky_relu(o, LRELU_SLOPE)
|
||||||
x = self.ups[i](x)
|
o = self.ups[i](o)
|
||||||
xs = None
|
z_sum = None
|
||||||
for j in range(self.num_kernels):
|
for j in range(self.num_kernels):
|
||||||
if xs is None:
|
if z_sum is None:
|
||||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
z_sum = self.resblocks[i * self.num_kernels + j](o)
|
||||||
else:
|
else:
|
||||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
z_sum += self.resblocks[i * self.num_kernels + j](o)
|
||||||
x = xs / self.num_kernels
|
o = z_sum / self.num_kernels
|
||||||
x = F.leaky_relu(x)
|
o = F.leaky_relu(o)
|
||||||
x = self.conv_post(x)
|
o = self.conv_post(o)
|
||||||
x = torch.tanh(x)
|
o = torch.tanh(o)
|
||||||
return x
|
return o
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, c):
|
def inference(self, c):
|
||||||
|
|
Loading…
Reference in New Issue