From 98a7271ce8e17ebc9533b0c9e7af8e05b70b4fb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eren=20G=C3=B6lge?=
Date: Mon, 26 Jul 2021 22:50:34 +0200
Subject: [PATCH] Refactor FastPitchv2

---
 TTS/tts/models/fast_pitch.py | 664 ++++++++++-------------------------
 TTS/utils/audio.py           |  42 +--
 2 files changed, 210 insertions(+), 496 deletions(-)

diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py
index 6f9cee36..c218535e 100644
--- a/TTS/tts/models/fast_pitch.py
+++ b/TTS/tts/models/fast_pitch.py
@@ -1,23 +1,23 @@
 from dataclasses import dataclass, field

 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 from coqpit import Coqpit
-from matplotlib.pyplot import plot
+from torch import nn

+from TTS.tts.layers.feed_forward.decoder import Decoder
+from TTS.tts.layers.feed_forward.encoder import Encoder
+from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
+from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
 from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
 from TTS.tts.models.base_tts import BaseTTS
 from TTS.tts.utils.data import sequence_mask
+from TTS.tts.utils.measures import alignment_diagonal_score
 from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
 from TTS.utils.audio import AudioProcessor

-# pylint: disable=dangerous-default-value
-

 class AlignmentEncoder(torch.nn.Module):
-    """Module for alignment text and mel spectrogram."""
-
     def __init__(
         self,
         in_query_channels=80,
@@ -31,32 +31,35 @@ class AlignmentEncoder(torch.nn.Module):
         self.log_softmax = torch.nn.LogSoftmax(dim=3)

         self.key_proj = nn.Sequential(
-            ConvNorm(
-                in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            nn.Conv1d(
+                in_key_channels,
+                in_key_channels * 2,
+                kernel_size=3,
+                padding=1,
+                bias=True,
             ),
             torch.nn.ReLU(),
-            ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+            nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
         )

         self.query_proj = nn.Sequential(
-            ConvNorm(
-                in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
+            nn.Conv1d(
+                in_query_channels,
+                in_query_channels * 2,
+                kernel_size=3,
+                padding=1,
+                bias=True,
             ),
             torch.nn.ReLU(),
-            ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False),
+            nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
             torch.nn.ReLU(),
-            ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False),
+            nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
         )

     def forward(
         self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
     ):
         """Forward pass of the aligner encoder.
-        Args:
-            queries (torch.tensor): query tensor.
-            keys (torch.tensor): key tensor.
-            mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain).
-            attn_prior (torch.tensor): prior for attention matrix.
         Shapes:
             - queries: :math:`(B, C, T_de)`
             - keys: :math:`(B, C_emb, T_en)`
@@ -84,365 +87,30 @@ class AlignmentEncoder(torch.nn.Module):
         return attn, attn_logprob


-def mask_from_lens(lens, max_len: int = None):
-    if max_len is None:
-        max_len = lens.max()
-    ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
-    mask = torch.lt(ids, lens.unsqueeze(1))
-    return mask
-
-
-class LinearNorm(torch.nn.Module):
-    def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
-        super(LinearNorm, self).__init__()
-        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
-
-        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
-
-    def forward(self, x):
-        return self.linear_layer(x)
-
-
-class ConvNorm(torch.nn.Module):
-    def __init__(
-        self,
-        in_channels,
-        out_channels,
-        kernel_size=1,
-        stride=1,
-        padding=None,
-        dilation=1,
-        bias=True,
-        w_init_gain="linear",
-        batch_norm=False,
-    ):
-        super(ConvNorm, self).__init__()
-        if padding is None:
-            assert kernel_size % 2 == 1
-            padding = int(dilation * (kernel_size - 1) / 2)
-
-        self.conv = torch.nn.Conv1d(
-            in_channels,
-            out_channels,
-            kernel_size=kernel_size,
-            stride=stride,
-            padding=padding,
-            dilation=dilation,
-            bias=bias,
-        )
-        self.norm = torch.nn.BatchNorm1D(out_channels) if batch_norm else None
-
-        torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
-
-    def forward(self, signal):
-        if self.norm is None:
-            return self.conv(signal)
-        else:
-            return self.norm(self.conv(signal))
-
-
-class ConvReLUNorm(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
-        super(ConvReLUNorm, self).__init__()
-        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2))
-        self.norm = torch.nn.LayerNorm(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
-
-    def forward(self, signal):
-        out = F.relu(self.conv(signal))
-        out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
-        return self.dropout(out)
-
-
-class PositionalEmbedding(nn.Module):
-    def __init__(self, demb):
-        super(PositionalEmbedding, self).__init__()
-        self.demb = demb
-        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
-        self.register_buffer("inv_freq", inv_freq)
-
-    def forward(self, pos_seq, bsz=None):
-        sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
-        pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
-        if bsz is not None:
-            return pos_emb[None, :, :].expand(bsz, -1, -1)
-        else:
-            return pos_emb[None, :, :]
-
-
-class PositionwiseConvFF(nn.Module):
-    def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
-        super(PositionwiseConvFF, self).__init__()
-
-        self.d_model = d_model
-        self.d_inner = d_inner
-        self.dropout = dropout
-
-        self.CoreNet = nn.Sequential(
-            nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
-            nn.ReLU(),
-            # nn.Dropout(dropout),  # worse convergence
-            nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
-            nn.Dropout(dropout),
-        )
-        self.layer_norm = nn.LayerNorm(d_model)
-        self.pre_lnorm = pre_lnorm
-
-    def forward(self, inp):
-        return self._forward(inp)
-
-    def _forward(self, inp):
-        if self.pre_lnorm:
-            # layer normalization + positionwise feed-forward
-            core_out = inp.transpose(1, 2)
-            core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
-            core_out = core_out.transpose(1, 2)
-
-            # residual connection
-            output = core_out + inp
-        else:
-            # positionwise feed-forward
-            core_out = inp.transpose(1, 2)
-            core_out = self.CoreNet(core_out)
-            core_out = core_out.transpose(1, 2)
-
-            # residual connection + layer normalization
-            output = self.layer_norm(inp + core_out).to(inp.dtype)
-
-        return output
-
-
-class MultiHeadAttn(nn.Module):
-    def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False):
-        super(MultiHeadAttn, self).__init__()
-
-        self.num_heads = num_heads
-        self.d_model = d_model
-        self.hidden_channels_head = hidden_channels_head
-        self.scale = 1 / (hidden_channels_head ** 0.5)
-        self.pre_lnorm = pre_lnorm
-
-        self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head)
-        self.drop = nn.Dropout(dropout)
-        self.dropout_attn = nn.Dropout(dropout_attn)
-        self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False)
-        self.layer_norm = nn.LayerNorm(d_model)
-
-    def forward(self, inp, attn_mask=None):
-        return self._forward(inp, attn_mask)
-
-    def _forward(self, inp, attn_mask=None):
-        residual = inp
-
-        if self.pre_lnorm:
-            # layer normalization
-            inp = self.layer_norm(inp)
-
-        num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head
-
-        head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
-        head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
-        head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
-        head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
-
-        q = head_q.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
-        k = head_k.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
-        v = head_v.permute(0, 2, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
-
-        attn_score = torch.bmm(q, k.transpose(1, 2))
-        attn_score.mul_(self.scale)
-
-        if attn_mask is not None:
-            attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
-            attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1)
-            attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
-
-        attn_prob = F.softmax(attn_score, dim=2)
-        attn_prob = self.dropout_attn(attn_prob)
-        attn_vec = torch.bmm(attn_prob, v)
-
-        attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head)
-        attn_vec = (
-            attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head)
-        )
-
-        # linear projection
-        attn_out = self.o_net(attn_vec)
-        attn_out = self.drop(attn_out)
-
-        if self.pre_lnorm:
-            # residual connection
-            output = residual + attn_out
-        else:
-            # residual connection + layer normalization
-            output = self.layer_norm(residual + attn_out)
-
-        output = output.to(attn_out.dtype)
-
-        return output
-
-
-class TransformerLayer(nn.Module):
-    def __init__(
-        self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs
-    ):
-        super(TransformerLayer, self).__init__()
-
-        self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs)
-        self.pos_ff = PositionwiseConvFF(
-            hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm")
-        )
-
-    def forward(self, dec_inp, mask=None):
-        output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
-        output *= mask
-        output = self.pos_ff(output)
-        output *= mask
-        return output
-
-
-class FFTransformer(nn.Module):
-    def __init__(
-        self,
-        num_layers,
-        num_heads,
-        hidden_channels,
-        hidden_channels_head,
-        hidden_channels_ffn,
-        kernel_size,
-        dropout,
-        dropout_attn,
-        dropemb=0.0,
-        pre_lnorm=False,
-    ):
-        super(FFTransformer, self).__init__()
-        self.hidden_channels = hidden_channels
-        self.num_heads = num_heads
-        self.hidden_channels_head = hidden_channels_head
-
-        self.pos_emb = PositionalEmbedding(self.hidden_channels)
-        self.drop = nn.Dropout(dropemb)
-        self.layers = nn.ModuleList()
-
-        for _ in range(num_layers):
-            self.layers.append(
-                TransformerLayer(
-                    num_heads,
-                    hidden_channels,
-                    hidden_channels_head,
-                    hidden_channels_ffn,
-                    kernel_size,
-                    dropout,
-                    dropout_attn=dropout_attn,
-                    pre_lnorm=pre_lnorm,
-                )
-            )
-
-    def forward(self, x, x_lengths, conditioning=0):
-        mask = mask_from_lens(x_lengths).unsqueeze(2)
-
-        pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype)
-        pos_emb = self.pos_emb(pos_seq) * mask
-
-        if conditioning is None:
-            conditioning = 0
-
-        out = self.drop(x + pos_emb + conditioning)
-
-        for layer in self.layers:
-            out = layer(out, mask=mask)
-
-        # out = self.drop(out)
-        return out, mask
-
-
-def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
-    """If target=None, then predicted durations are applied"""
-    dtype = enc_out.dtype
-    reps = durations.float() / pace
-    reps = (reps + 0.5).long()
-    dec_lens = reps.sum(dim=1)
-
-    max_len = dec_lens.max()
-    reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
-    reps_cumsum = reps_cumsum.to(dtype)
-
-    range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
-    mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
-    mult = mult.to(dtype)
-    en_ex = torch.matmul(mult, enc_out)
-
-    if mel_max_len:
-        en_ex = en_ex[:, :mel_max_len]
-        dec_lens = torch.clamp_max(dec_lens, mel_max_len)
-    return en_ex, dec_lens
-
-
-class TemporalPredictor(nn.Module):
-    """Predicts a single float per each temporal location"""
-
-    def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2):
-        super(TemporalPredictor, self).__init__()
-
-        self.layers = nn.Sequential(
-            *[
-                ConvReLUNorm(
-                    input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout
-                )
-                for i in range(num_layers)
-            ]
-        )
-        self.fc = nn.Linear(filter_size, 1, bias=True)
-
-    def forward(self, enc_out, enc_out_mask):
-        out = enc_out * enc_out_mask
-        out = self.layers(out.transpose(1, 2)).transpose(1, 2)
-        out = self.fc(out) * enc_out_mask
-        return out.squeeze(-1)
-
-
 @dataclass
 class FastPitchArgs(Coqpit):
-    num_chars: int = 100
+    num_chars: int = None
     out_channels: int = 80
-    hidden_channels: int = 384
+    hidden_channels: int = 256
     num_speakers: int = 0
     duration_predictor_hidden_channels: int = 256
     duration_predictor_dropout: float = 0.1
     duration_predictor_kernel_size: int = 3
     duration_predictor_dropout_p: float = 0.1
-    duration_predictor_num_layers: int = 2
     pitch_predictor_hidden_channels: int = 256
     pitch_predictor_dropout: float = 0.1
     pitch_predictor_kernel_size: int = 3
     pitch_predictor_dropout_p: float = 0.1
     pitch_embedding_kernel_size: int = 3
-    pitch_predictor_num_layers: int = 2
     positional_encoding: bool = True
     length_scale: int = 1
     encoder_type: str = "fftransformer"
     encoder_params: dict = field(
-        default_factory=lambda: {
-            "hidden_channels_head": 64,
-            "hidden_channels_ffn": 1536,
-            "num_heads": 1,
-            "num_layers": 6,
-            "kernel_size": 3,
-            "dropout": 0.1,
-            "dropout_attn": 0.1,
-        }
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
     decoder_type: str = "fftransformer"
     decoder_params: dict = field(
-        default_factory=lambda: {
-            "hidden_channels_head": 64,
-            "hidden_channels_ffn": 1536,
-            "num_heads": 1,
-            "num_layers": 6,
-            "kernel_size": 3,
-            "dropout": 0.1,
-            "dropout_attn": 0.1,
-        }
+        default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
     )
     use_d_vector: bool = False
     d_vector_dim: int = 0
@@ -477,7 +145,9 @@ class FastPitch(BaseTTS):
         >>> model = FastPitch(config)
     """

+    # pylint: disable=dangerous-default-value
     def __init__(self, config: Coqpit):
+
         super().__init__()

         if "characters" in config:
@@ -496,44 +166,54 @@ class FastPitch(BaseTTS):

         self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale

-        self.encoder = FFTransformer(
-            hidden_channels=args.hidden_channels,
-            **args.encoder_params,
+        self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
+
+        self.encoder = Encoder(
+            config.model_args.hidden_channels,
+            config.model_args.hidden_channels,
+            config.model_args.encoder_type,
+            config.model_args.encoder_params,
+            config.model_args.d_vector_dim,
         )
-        # if n_speakers > 1:
-        #     self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
-        # else:
-        #     self.speaker_emb = None
-        # self.speaker_emb_weight = speaker_emb_weight
-        self.emb = nn.Embedding(args.num_chars, args.hidden_channels)

+        if config.model_args.positional_encoding:
+            self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)

-        self.duration_predictor = TemporalPredictor(
-            args.hidden_channels,
-            filter_size=args.duration_predictor_hidden_channels,
-            kernel_size=args.duration_predictor_kernel_size,
-            dropout=args.duration_predictor_dropout_p,
-            num_layers=args.duration_predictor_num_layers,
+        self.decoder = Decoder(
+            config.model_args.out_channels,
+            config.model_args.hidden_channels,
+            config.model_args.decoder_type,
+            config.model_args.decoder_params,
         )

-        self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params)
+        self.duration_predictor = DurationPredictor(
+            config.model_args.hidden_channels + config.model_args.d_vector_dim,
+            config.model_args.duration_predictor_hidden_channels,
+            config.model_args.duration_predictor_kernel_size,
+            config.model_args.duration_predictor_dropout_p,
         )

-        self.pitch_predictor = TemporalPredictor(
-            args.hidden_channels,
-            filter_size=args.pitch_predictor_hidden_channels,
-            kernel_size=args.pitch_predictor_kernel_size,
-            dropout=args.pitch_predictor_dropout_p,
-            num_layers=args.pitch_predictor_num_layers,
+        self.pitch_predictor = DurationPredictor(
+            config.model_args.hidden_channels + config.model_args.d_vector_dim,
+            config.model_args.pitch_predictor_hidden_channels,
+            config.model_args.pitch_predictor_kernel_size,
+            config.model_args.pitch_predictor_dropout_p,
         )

         self.pitch_emb = nn.Conv1d(
             1,
-            args.hidden_channels,
-            kernel_size=args.pitch_embedding_kernel_size,
-            padding=int((args.pitch_embedding_kernel_size - 1) / 2),
+            config.model_args.hidden_channels,
+            kernel_size=config.model_args.pitch_embedding_kernel_size,
+            padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2),
         )

-        self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)
+        if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
+            # speaker embedding layer
+            self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
+            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
+
+        if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
+            self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)

         if args.use_aligner:
             self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
@@ -555,64 +235,109 @@ class FastPitch(BaseTTS):
         """
         attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
         attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
-        o_en_ex = torch.matmul(attn.transpose(1, 2), en)
-        return o_en_ex, attn.transpose(1, 2)
+        o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
+        return o_en_ex, attn
+
+    def format_durations(self, o_dr_log, x_mask):
+        o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
+        o_dr[o_dr < 1] = 1.0
+        o_dr = torch.round(o_dr)
+        return o_dr
+
+    @staticmethod
+    def _concat_speaker_embedding(o_en, g):
+        g_exp = g.expand(-1, -1, o_en.size(-1))  # [B, C, T_en]
+        o_en = torch.cat([o_en, g_exp], 1)
+        return o_en
+
+    def _sum_speaker_embedding(self, x, g):
+        # project g to decoder dim.
+        if hasattr(self, "proj_g"):
+            g = self.proj_g(g)
+        return x + g
+
+    def _forward_encoder(self, x, x_lengths, g=None):
+        if hasattr(self, "emb_g"):
+            g = nn.functional.normalize(self.emb_g(g))  # [B, C, 1]
+
+        if g is not None:
+            g = g.unsqueeze(-1)
+
+        # [B, T, C]
+        x_emb = self.emb(x)
+
+        # compute sequence masks
+        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
+
+        # encoder pass
+        o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
+
+        # speaker conditioning for duration predictor
+        if g is not None:
+            o_en_dp = self._concat_speaker_embedding(o_en, g)
+        else:
+            o_en_dp = o_en
+        return o_en, o_en_dp, x_mask, g, x_emb
+
+    def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
+        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
+        # expand o_en with durations
+        o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
+        # positional encoding
+        if hasattr(self, "pos_encoder"):
+            o_en_ex = self.pos_encoder(o_en_ex, y_mask)
+        # speaker embedding
+        if g is not None:
+            o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
+        # decoder pass
+        o_de = self.decoder(o_en_ex, y_mask, g=g)
+        return o_de.transpose(1, 2), attn.transpose(1, 2)
+
+    def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None):
+        o_pitch = self.pitch_predictor(o_en, x_mask)
+        if pitch is not None:
+            avg_pitch = average_pitch(pitch, dr)
+            o_pitch_emb = self.pitch_emb(avg_pitch)
+            return o_pitch_emb, o_pitch, avg_pitch
+        o_pitch_emb = self.pitch_emb(o_pitch)
+        return o_pitch_emb, o_pitch
+
+    def _forward_aligner(self, y, embedding, x_mask, y_mask):
+        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
+        alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
+        alignment_mas = maximum_path(
+            alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
+        )
+        o_alignment_dur = torch.sum(alignment_mas, -1)
+        return o_alignment_dur, alignment_logprob, alignment_mas

     def forward(
         self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
-    ):
-        speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
+    ):  # pylint: disable=unused-argument
+        """
+        Shapes:
+            x: :math:`[B, T_max]`
+            x_lengths: :math:`[B]`
+            y_lengths: :math:`[B]`
+            y: :math:`[B, T_max2]`
+            dr: :math:`[B, T_max]`
+            g: :math:`[B, C]`
+            pitch: :math:`[B, 1, T]`
+        """
+        g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
"d_vectors" in aux_input else None y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - o_alignment_dur = None - alignment_logprob = None - alignment_mas = None - - # Calculate speaker embedding - # if self.speaker_emb is None: - # speaker_embedding = 0 - # else: - # speaker_embedding = self.speaker_emb(speaker).unsqueeze(1) - # speaker_embedding.mul_(self.speaker_emb_weight) - - # character embedding - embedding = self.emb(x) - - # Input FFT - o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding) - - # Embedded for predictors - o_en_dr, mask_en_dr = o_en, mask_en - - # Predict durations - o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr) + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g) + if self.config.model_args.detach_duration_predictor: + o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) + else: + o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) - - # Aligner if self.use_aligner: - alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) - alignment_mas = maximum_path( - alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() - ) - o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1)) - avg_pitch = average_pitch(pitch, o_alignment_dur) + o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask) dr = o_alignment_dur - - # TODO: move this to the dataset - avg_pitch = average_pitch(pitch, dr) - - # Predict pitch - o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1) - pitch_emb = self.pitch_emb(avg_pitch) - o_en = o_en + pitch_emb.transpose(1, 2) - - # len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len) - o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) - - # Output FFT - o_de, _ = self.decoder(o_en_ex, y_lengths) - o_de = self.proj(o_de) + o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) + o_en = o_en + o_pitch_emb + o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) outputs = { "model_outputs": o_de, "durations_log": o_dr_log.squeeze(1), @@ -620,66 +345,55 @@ class FastPitch(BaseTTS): "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, - "alignment_mas": alignment_mas, + "alignment_mas": alignment_mas.transpose(1, 2), "o_alignment_dur": o_alignment_dur, "alignment_logprob": alignment_logprob, } return outputs @torch.no_grad() - def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}): # pylint: disable=unused-argument - speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0 - + def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument + """ + Shapes: + x: [B, T_max] + x_lengths: [B] + g: [B, C] + """ + g = aux_input["d_vectors"] if "d_vectors" in aux_input else None + x_lengths = torch.tensor(x.shape[1:2]).to(x.device) # input sequence should be greated than the max convolution size inference_padding = 5 if x.shape[1] < 13: inference_padding += 13 - x.shape[1] - # pad input to prevent dropping the last word x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) - x_lengths = 
-
-        # character embedding
-        embedding = self.emb(x)
-
-        # if self.speaker_emb is None:
-        # else:
-        #     speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
-        #     spk_emb = self.speaker_emb(speaker).unsqueeze(1)
-        #     spk_emb.mul_(self.speaker_emb_weight)
-
-        # Input FFT
-        o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)
-
-        # Predict durations
-        o_dr_log = self.duration_predictor(o_en, mask_en)
-        o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
-        o_dr = o_dr * self.length_scale
-
-        # Pitch over chars
-        o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
-
+        o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_lengths, g)
+        # duration predictor pass
+        o_dr_log = self.duration_predictor(o_en_dp, x_mask)
+        o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
+        y_lengths = o_dr.sum(1)
+        # pitch predictor pass
+        o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
         # if pitch_transform is not None:
         #     if self.pitch_std[0] == 0.0:
         #         # XXX LJSpeech-1.1 defaults
         #         mean, std = 218.14, 67.24
         #     else:
         #         mean, std = self.pitch_mean[0], self.pitch_std[0]
-        #     pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std)
-
-        o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2)
+        #     pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
+        # if pitch_tgt is None:
+        #     pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
+        # else:
+        #     pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
         o_en = o_en + o_pitch_emb
-        y_lengths = o_dr.sum(1)
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
-        y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
-
-        o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)
-        o_de, _ = self.decoder(o_en_ex, y_lengths)
-        o_de = self.proj(o_de)
-
-        outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log}
+        o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
+        outputs = {
+            "model_outputs": o_de.transpose(1, 2),
+            "alignments": attn,
+            "pitch": o_pitch,
+            "durations_log": o_dr_log,
+        }
         return outputs

     def train_step(self, batch: dict, criterion: nn.Module):
@@ -735,8 +449,8 @@ class FastPitch(BaseTTS):
         }

         if self.config.model_args.use_aligner and self.training:
-            alignment_mas = outputs["alignment_mas"]
-            figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False)
+            alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy()
+            figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False)

         # Sample audio
         train_audio = ap.inv_melspectrogram(pred_spec.T)
diff --git a/TTS/utils/audio.py b/TTS/utils/audio.py
index 96b9a1a1..6a74b3c8 100644
--- a/TTS/utils/audio.py
+++ b/TTS/utils/audio.py
@@ -647,29 +647,29 @@ class AudioProcessor(object):
         #     frame_period=1000 * self.hop_length / self.sample_rate,
         # )
         # f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
-        f0, _, _, _ = compute_yin(
-            x,
-            self.sample_rate,
-            self.win_length,
-            self.hop_length,
-            65 if self.mel_fmin == 0 else self.mel_fmin,
-            self.mel_fmax,
-        )
-        # import pyworld as pw
-        # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
-        #                frame_period=self.hop_length / self.sample_rate * 1000)
-        pad = int((self.win_length / self.hop_length) / 2)
-        f0 = [0.0] * pad + f0 + [0.0] * pad
-        f0 = np.array(f0, dtype=np.float32)
-
-        # f01, _, _ = librosa.pyin(
+        # f0, _, _, _ = compute_yin(
         #     x,
-        #     fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
-        #     fmax=self.mel_fmax,
-        #     frame_length=self.win_length,
-        #     sr=self.sample_rate,
-        #     fill_na=0.0,
+        #     self.sample_rate,
+        #     self.win_length,
+        #     self.hop_length,
+        #     65 if self.mel_fmin == 0 else self.mel_fmin,
+        #     self.mel_fmax,
         # )
+        # # import pyworld as pw
+        # # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
+        # #                frame_period=self.hop_length / self.sample_rate * 1000)
+        # pad = int((self.win_length / self.hop_length) / 2)
+        # f0 = [0.0] * pad + f0 + [0.0] * pad
+        # f0 = np.array(f0, dtype=np.float32)
+
+        f0, _, _ = librosa.pyin(
+            x,
+            fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
+            fmax=self.mel_fmax,
+            frame_length=self.win_length,
+            sr=self.sample_rate,
+            fill_na=0.0,
+        )
         # f02 = librosa.yin(
         #     x,
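The refactored expand_encoder_outputs / _forward_decoder path in this patch repeats each encoder frame according to its predicted (or aligner-derived) duration before the decoder runs. The snippet below is a minimal, self-contained PyTorch sketch of that duration-expansion step under assumed shapes (en: [B, C, T_en], dr: [B, T_en]); the names expand_with_durations and the local sequence_mask are illustrative stand-ins, not the repo's generate_path-based helpers.

# Illustrative sketch only: expands encoder frames by integer durations via a
# binary alignment matrix, mirroring the idea behind expand_encoder_outputs.
import torch


def sequence_mask(lengths, max_len=None):
    # [B] lengths -> [B, T] boolean mask
    max_len = max_len or int(lengths.max())
    ids = torch.arange(max_len, device=lengths.device)
    return ids.unsqueeze(0) < lengths.unsqueeze(1)


def expand_with_durations(en, dr, x_mask, y_mask):
    # en: [B, C, T_en], dr: [B, T_en] integer durations,
    # x_mask: [B, 1, T_en], y_mask: [B, 1, T_de]
    t_de = y_mask.shape[-1]
    ends = torch.cumsum(dr, dim=1)          # last decoder frame (exclusive) per token
    starts = ends - dr                      # first decoder frame per token
    frames = torch.arange(t_de, device=en.device).view(1, t_de, 1)
    # attn[b, t_de, t_en] = 1 if decoder frame t_de belongs to input token t_en
    attn = (frames >= starts.unsqueeze(1)) & (frames < ends.unsqueeze(1))
    attn = attn.float() * y_mask.transpose(1, 2) * x_mask
    # repeat encoder frames to decoder length: [B, C, T_de]
    o_en_ex = torch.matmul(attn, en.transpose(1, 2)).transpose(1, 2)
    return o_en_ex, attn


if __name__ == "__main__":
    en = torch.randn(2, 4, 3)                                            # [B, C, T_en]
    dr = torch.tensor([[2, 1, 3], [1, 1, 0]])                            # per-token durations
    x_mask = sequence_mask(torch.tensor([3, 2]), 3).unsqueeze(1).float() # [B, 1, T_en]
    y_mask = sequence_mask(dr.sum(1)).unsqueeze(1).float()               # [B, 1, T_de]
    o_en_ex, attn = expand_with_durations(en, dr, x_mask, y_mask)
    print(o_en_ex.shape, attn.shape)  # torch.Size([2, 4, 6]) torch.Size([2, 6, 3])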