From 7cecd2fb2e2630204ed760b7de6d8ee2f2a53ed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Wed, 7 Apr 2021 19:19:03 +0200 Subject: [PATCH] add hifigan D --- TTS/vocoder/models/hifigan_discriminator.py | 212 ++++++++++++++++++ .../models/multi_period_discriminator.py | 77 ------- 2 files changed, 212 insertions(+), 77 deletions(-) create mode 100644 TTS/vocoder/models/hifigan_discriminator.py delete mode 100644 TTS/vocoder/models/multi_period_discriminator.py diff --git a/TTS/vocoder/models/hifigan_discriminator.py b/TTS/vocoder/models/hifigan_discriminator.py new file mode 100644 index 00000000..5e508a99 --- /dev/null +++ b/TTS/vocoder/models/hifigan_discriminator.py @@ -0,0 +1,212 @@ +# adopted from https://github.com/jik876/hifi-gan/blob/master/models.py + +import torch +from torch import nn +from torch.nn import functional as F + + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + """HiFiGAN Periodic Discriminator + + Takes every Pth value from the input waveform and applied a stack of convoluations. + + Note: + if `period` is 2 + `waveform = [1, 2, 3, 4, 5, 6 ...] --> [1, 3, 5 ... ] --> convs -> score, feat` + + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super().__init__() + self.period = period + get_padding = lambda k, d: int((k*d - d)/2) + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList([ + norm_f(nn.Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [Tensor]: discriminator scores per sample in the batch. + [List[Tensor]]: list of features from each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + feat = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + + return x, feat + + +class MultiPeriodDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Period Discriminator (MPD) + Wrapper for the `PeriodDiscriminator` to apply it in different periods. + Periods are suggested to be prime numbers to reduce the overlap between each discriminator. + """ + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + [List[Tensor]]: list of scores from each discriminator. + [List[List[Tensor]]]: list of list of features from each discriminator's each convolutional layer. + + Shapes: + x: [B, 1, T] + """ + scores = [] + feats = [] + for _, d in enumerate(self.discriminators): + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class DiscriminatorS(torch.nn.Module): + """HiFiGAN Scale Discriminator. + It is similar to `MelganDiscriminator` but with a specific architecture explained in the paper. + + Args: + use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm. + + """ + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm + self.convs = nn.ModuleList([ + norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)), + norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + Tensor: discriminator scores. + List[Tensor]: list of features from the convolutiona layers. + """ + feat = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + feat.append(x) + x = self.conv_post(x) + feat.append(x) + x = torch.flatten(x, 1, -1) + return x, feat + + +class MultiScaleDiscriminator(torch.nn.Module): + """HiFiGAN Multi-Scale Discriminator. + It is similar to `MultiScaleMelganDiscriminator` but specially tailored for HiFiGAN as in the paper. + """ + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList([ + nn.AvgPool1d(4, 2, padding=2), + nn.AvgPool1d(4, 2, padding=2) + ]) + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores = [] + feats = [] + for i, d in enumerate(self.discriminators): + if i != 0: + x = self.meanpools[i-1](x) + score, feat = d(x) + scores.append(score) + feats.append(feat) + return scores, feats + + +class HifiganDiscriminator(nn.Module): + """HiFiGAN discriminator wrapping MPD and MSD. + """ + def __init__(self): + super().__init__() + self.mpd = MultiPeriodDiscriminator() + self.msd = MultiScaleDiscriminator() + + def forward(self, x): + """ + Args: + x (Tensor): input waveform. + + Returns: + List[Tensor]: discriminator scores. + List[List[Tensor]]: list of list of features from each layers of each discriminator. + """ + scores, feats = self.msd(x) + scores_, feats_ = self.mpd(x) + scores += scores_ + feats += feats_ + return scores, feats diff --git a/TTS/vocoder/models/multi_period_discriminator.py b/TTS/vocoder/models/multi_period_discriminator.py deleted file mode 100644 index 8f821a87..00000000 --- a/TTS/vocoder/models/multi_period_discriminator.py +++ /dev/null @@ -1,77 +0,0 @@ -from torch import nn -import torch.nn.functional as F -from TTS.vocoder.models.melgan_multiscale_discriminator import MelganMultiscaleDiscriminator - - -class PeriodDiscriminator(nn.Module): - def __init__(self, period): - super(PeriodDiscriminator, self).__init__() - layer = [] - self.period = period - inp = 1 - for l in range(4): - out = int(2**(5 + l + 1)) - layer += [ - nn.utils.weight_norm( - nn.Conv2d(inp, out, kernel_size=(5, 1), stride=(3, 1))), - nn.LeakyReLU(0.2) - ] - inp = out - self.layer = nn.Sequential(*layer) - self.output = nn.Sequential( - nn.utils.weight_norm(nn.Conv2d(out, 1024, kernel_size=(5, 1))), - nn.LeakyReLU(0.2), - nn.utils.weight_norm(nn.Conv2d(1024, 1, kernel_size=(3, 1)))) - - def forward(self, x): - batch_size = x.shape[0] - pad = self.period - (x.shape[-1] % self.period) - x = F.pad(x, (0, pad)) - y = x.view(batch_size, -1, self.period).contiguous() - y = y.unsqueeze(1) - out1 = self.layer(y) - return self.output(out1) - - -class HifiDiscriminator(nn.Module): - def __init__(self, - periods=[2, 3, 5, 7, 11], - in_channels=1, - out_channels=1, - num_scales=3, - kernel_sizes=(5, 3), - base_channels=64, - max_channels=1024, - downsample_factors=(2, 2, 4, 4), - pooling_kernel_size=4, - pooling_stride=2, - pooling_padding=1): - super().__init__() - self.discriminators = nn.ModuleList([ - PeriodDiscriminator(periods[0]), - PeriodDiscriminator(periods[1]), - PeriodDiscriminator(periods[2]), - PeriodDiscriminator(periods[3]), - PeriodDiscriminator(periods[4]) - ]) - - self.msd = MelganMultiscaleDiscriminator( - in_channels=in_channels, - out_channels=out_channels, - num_scales=num_scales, - kernel_sizes=kernel_sizes, - base_channels=base_channels, - max_channels=max_channels, - downsample_factors=downsample_factors, - pooling_kernel_size=pooling_kernel_size, - pooling_stride=pooling_stride, - pooling_padding=pooling_padding, - groups_denominator=32, - max_groups=16) - - def forward(self, x): - scores, feats = self.msd(x) - for key, disc in enumerate(self.discriminators): - score = disc(x) - scores.append(score) - return scores, feats