diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 805f36d6..100b8fb3 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -756,7 +756,7 @@ class FastPitchLoss(nn.Module): loss = loss + self.aligner_loss_alpha * aligner_loss return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss - if self.binary_alignment_loss_alpha > 0: + if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None: binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft) loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss diff --git a/TTS/tts/models/fast_pitch.py b/TTS/tts/models/fast_pitch.py index b8f346c7..352aebfa 100644 --- a/TTS/tts/models/fast_pitch.py +++ b/TTS/tts/models/fast_pitch.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Tuple +from typing import Dict, Tuple import torch from coqpit import Coqpit @@ -8,6 +8,7 @@ from torch.cuda.amp.autocast_mode import autocast from TTS.tts.layers.feed_forward.decoder import Decoder from TTS.tts.layers.feed_forward.encoder import Encoder +from TTS.tts.layers.generic.aligner import AlignmentNetwork from TTS.tts.layers.generic.pos_encoding import PositionalEncoding from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path @@ -15,87 +16,101 @@ from TTS.tts.models.base_tts import BaseTTS from TTS.tts.utils.data import sequence_mask from TTS.tts.utils.visual import plot_alignment, plot_spectrogram from TTS.utils.audio import AudioProcessor - - -class AlignmentEncoder(torch.nn.Module): - def __init__( - self, - in_query_channels=80, - in_key_channels=512, - attn_channels=80, - temperature=0.0005, - ): - super().__init__() - self.temperature = temperature - self.softmax = torch.nn.Softmax(dim=3) - self.log_softmax = torch.nn.LogSoftmax(dim=3) - - self.key_layer = nn.Sequential( - nn.Conv1d( - in_key_channels, - in_key_channels * 2, - kernel_size=3, - padding=1, - bias=True, - ), - torch.nn.ReLU(), - nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True), - ) - - self.query_layer = nn.Sequential( - nn.Conv1d( - in_query_channels, - in_query_channels * 2, - kernel_size=3, - padding=1, - bias=True, - ), - torch.nn.ReLU(), - nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True), - torch.nn.ReLU(), - nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True), - ) - - def forward( - self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None - ) -> Tuple[torch.tensor, torch.tensor]: - """Forward pass of the aligner encoder. - Shapes: - - queries: :math:`[B, C, T_de]` - - keys: :math:`[B, C_emb, T_en]` - - mask: :math:`[B, T_de]` - Output: - attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask. - attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities. - """ - key_out = self.key_layer(keys) - query_out = self.query_layer(queries) - attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2 - attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True) - if attn_prior is not None: - attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8) - if mask is not None: - attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf")) - attn = self.softmax(attn_logp) - return attn, attn_logp +from TTS.utils.soft_dtw import SoftDTW @dataclass class FastPitchArgs(Coqpit): + """Fast Pitch Model arguments. + + Args: + + num_chars (int): + Number of characters in the vocabulary. Defaults to 100. + + out_channels (int): + Number of output channels. Defaults to 80. + + hidden_channels (int): + Number of base hidden channels of the model. Defaults to 512. + + num_speakers (int): + Number of speakers for the speaker embedding layer. Defaults to 0. + + duration_predictor_hidden_channels (int): + Number of hidden channels in the duration predictor. Defaults to 256. + + duration_predictor_dropout_p (float): + Dropout rate for the duration predictor. Defaults to 0.1. + + duration_predictor_kernel_size (int): + Kernel size of conv layers in the duration predictor. Defaults to 3. + + pitch_predictor_hidden_channels (int): + Number of hidden channels in the pitch predictor. Defaults to 256. + + pitch_predictor_dropout_p (float): + Dropout rate for the pitch predictor. Defaults to 0.1. + + pitch_predictor_kernel_size (int): + Kernel size of conv layers in the pitch predictor. Defaults to 3. + + pitch_embedding_kernel_size (int): + Kernel size of the projection layer in the pitch predictor. Defaults to 3. + + positional_encoding (bool): + Whether to use positional encoding. Defaults to True. + + positional_encoding_use_scale (bool): + Whether to use a learnable scale coeff in the positional encoding. Defaults to True. + + length_scale (int): + Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0. + + encoder_type (str): + Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`. + Defaults to `fftransformer` as in the paper. + + encoder_params (dict): + Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + decoder_type (str): + Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`. + Defaults to `fftransformer` as in the paper. + + decoder_params (str): + Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}``` + + use_d_vetor (bool): + Whether to use precomputed d-vectors for multi-speaker training. Defaults to False. + + d_vector_dim (int): + Number of channels of the d-vectors. Defaults to 0. + + detach_duration_predictor (bool): + Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss + does not pass to the earlier layers. Defaults to True. + + max_duration (int): + Maximum duration accepted by the model. Defaults to 75. + + use_aligner (bool): + Use aligner network to learn the text to speech alignment. Defaults to True. + """ + num_chars: int = None out_channels: int = 80 - hidden_channels: int = 256 + hidden_channels: int = 384 num_speakers: int = 0 duration_predictor_hidden_channels: int = 256 - duration_predictor_dropout: float = 0.1 duration_predictor_kernel_size: int = 3 duration_predictor_dropout_p: float = 0.1 pitch_predictor_hidden_channels: int = 256 - pitch_predictor_dropout: float = 0.1 pitch_predictor_kernel_size: int = 3 pitch_predictor_dropout_p: float = 0.1 pitch_embedding_kernel_size: int = 3 positional_encoding: bool = True + poisitonal_encoding_use_scale: bool = True length_scale: int = 1 encoder_type: str = "fftransformer" encoder_params: dict = field( @@ -109,14 +124,16 @@ class FastPitchArgs(Coqpit): d_vector_dim: int = 0 detach_duration_predictor: bool = False max_duration: int = 75 - use_gt_duration: bool = True use_aligner: bool = True class FastPitch(BaseTTS): """FastPitch model. Very similart to SpeedySpeech model but with pitch prediction. - Paper abstract: + Paper:: + https://arxiv.org/abs/2006.06873 + + Paper abstract:: We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental frequency contours. The model predicts pitch contours during inference. By altering these predictions, the generated speech can be more expressive, better match the semantic of the utterance, and in the end @@ -126,9 +143,6 @@ class FastPitch(BaseTTS): and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time factor for mel-spectrogram synthesis of a typical utterance." - Notes: - TODO - Args: config (Coqpit): Model coqpit class. @@ -143,95 +157,138 @@ class FastPitch(BaseTTS): super().__init__() - if "characters" in config: - # loading from FasrPitchConfig - _, self.config, num_chars = self.get_characters(config) - config.model_args.num_chars = num_chars - args = self.config.model_args - else: - # loading from FastPitchArgs + # don't use isintance not to import recursively + if config.__class__.__name__ == "FastPitchConfig": + if "characters" in config: + # loading from FasrPitchConfig + _, self.config, num_chars = self.get_characters(config) + config.model_args.num_chars = num_chars + self.args = self.config.model_args + else: + # loading from FastPitchArgs + self.config = config + self.args = config.model_args + elif isinstance(config, FastPitchArgs): + self.args = config self.config = config - args = config + else: + raise ValueError("config must be either a VitsConfig or Vitsself.args") - self.max_duration = args.max_duration - self.use_gt_duration = args.use_gt_duration - self.use_aligner = args.use_aligner + self.max_duration = self.args.max_duration + self.use_aligner = self.args.use_aligner + self.use_binary_alignment_loss = False - self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale - - self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels) - - self.encoder = Encoder( - config.model_args.hidden_channels, - config.model_args.hidden_channels, - config.model_args.encoder_type, - config.model_args.encoder_params, - config.model_args.d_vector_dim, + self.length_scale = ( + float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale ) - if config.model_args.positional_encoding: - self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels) + self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels) + + self.encoder = Encoder( + self.args.hidden_channels, + self.args.hidden_channels, + self.args.encoder_type, + self.args.encoder_params, + self.args.d_vector_dim, + ) + + if self.args.positional_encoding: + self.pos_encoder = PositionalEncoding(self.args.hidden_channels) self.decoder = Decoder( - config.model_args.out_channels, - config.model_args.hidden_channels, - config.model_args.decoder_type, - config.model_args.decoder_params, + self.args.out_channels, + self.args.hidden_channels, + self.args.decoder_type, + self.args.decoder_params, ) self.duration_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.duration_predictor_hidden_channels, - config.model_args.duration_predictor_kernel_size, - config.model_args.duration_predictor_dropout_p, + self.args.hidden_channels + self.args.d_vector_dim, + self.args.duration_predictor_hidden_channels, + self.args.duration_predictor_kernel_size, + self.args.duration_predictor_dropout_p, ) self.pitch_predictor = DurationPredictor( - config.model_args.hidden_channels + config.model_args.d_vector_dim, - config.model_args.pitch_predictor_hidden_channels, - config.model_args.pitch_predictor_kernel_size, - config.model_args.pitch_predictor_dropout_p, + self.args.hidden_channels + self.args.d_vector_dim, + self.args.pitch_predictor_hidden_channels, + self.args.pitch_predictor_kernel_size, + self.args.pitch_predictor_dropout_p, ) self.pitch_emb = nn.Conv1d( 1, - config.model_args.hidden_channels, - kernel_size=config.model_args.pitch_embedding_kernel_size, - padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2), + self.args.hidden_channels, + kernel_size=self.args.pitch_embedding_kernel_size, + padding=int((self.args.pitch_embedding_kernel_size - 1) / 2), ) - if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector: + if self.args.num_speakers > 1 and not self.args.use_d_vector: # speaker embedding layer - self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim) + self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim) nn.init.uniform_(self.emb_g.weight, -0.1, 0.1) - if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels: - self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1) + if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels: + self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1) - if args.use_aligner: - self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels) + if self.args.use_aligner: + self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels) @staticmethod - def expand_encoder_outputs(en, dr, x_mask, y_mask): + def generate_attn(dr, x_mask, y_mask=None): + """Generate an attention mask from the durations. + + Shapes + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` + """ + # compute decode mask from the durations + if y_mask is None: + y_lengths = dr.sum(1).long() + y_lengths[y_lengths < 1] = 1 + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype) + attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) + attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype) + return attn + + def expand_encoder_outputs(self, en, dr, x_mask, y_mask): """Generate attention alignment map from durations and expand encoder outputs - Example: - encoder output: [a,b,c,d] - durations: [1, 3, 2, 1] + Shapes + - en: :math:`(B, D_{en}, T_{en})` + - dr: :math:`(B, T_{en})` + - x_mask: :math:`(B, T_{en})` + - y_mask: :math:`(B, T_{de})` - expanded: [a, b, b, b, c, c, d] - attention map: [[0, 0, 0, 0, 0, 0, 1], - [0, 0, 0, 0, 1, 1, 0], - [0, 1, 1, 1, 0, 0, 0], - [1, 0, 0, 0, 0, 0, 0]] + Examples: + - encoder output: :math:`[a,b,c,d]` + - durations: :math:`[1, 3, 2, 1]` + + - expanded: :math:`[a, b, b, b, c, c, d]` + - attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]` """ - attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype) - o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2) + attn = self.generate_attn(dr, x_mask, y_mask) + o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2) return o_en_ex, attn def format_durations(self, o_dr_log, x_mask): + """Format predicted durations. + 1. Convert to linear scale from log scale + 2. Apply the length scale for speed adjustment + 3. Apply masking. + 4. Cast 0 durations to 1. + 5. Round the duration values. + + Args: + o_dr_log: Log scale durations. + x_mask: Input text mask. + + Shapes: + - o_dr_log: :math:`(B, T_{de})` + - x_mask: :math:`(B, T_{en})` + """ o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale o_dr[o_dr < 1] = 1.0 o_dr = torch.round(o_dr) @@ -249,22 +306,39 @@ class FastPitch(BaseTTS): g = self.proj_g(g) return x + g - def _forward_encoder(self, x, x_lengths, g=None): + def _forward_encoder( + self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None + ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Encoding forward pass. + + 1. Embed speaker IDs if multi-speaker mode. + 2. Embed character sequences. + 3. Run the encoder network. + 4. Concat speaker embedding to the encoder output for the duration predictor. + + Args: + x (torch.LongTensor): Input sequence IDs. + x_mask (torch.FloatTensor): Input squence mask. + g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None. + + Returns: + Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]: + encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings, + character embeddings + + Shapes: + - x: :math:`(B, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - g: :math:`(B, C)` + """ if hasattr(self, "emb_g"): g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1] - if g is not None: g = g.unsqueeze(-1) - # [B, T, C] x_emb = self.emb(x) - - # compute sequence masks - x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype) - # encoder pass o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask) - # speaker conditioning for duration predictor if g is not None: o_en_dp = self._concat_speaker_embedding(o_en, g) @@ -272,8 +346,33 @@ class FastPitch(BaseTTS): o_en_dp = o_en return o_en, o_en_dp, x_mask, g, x_emb - def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g): - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype) + def _forward_decoder( + self, + o_en: torch.FloatTensor, + dr: torch.IntTensor, + x_mask: torch.FloatTensor, + y_lengths: torch.IntTensor, + g: torch.FloatTensor, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Decoding forward pass. + + 1. Compute the decoder output mask + 2. Expand encoder output with the durations. + 3. Apply position encoding. + 4. Add speaker embeddings if multi-speaker mode. + 5. Run the decoder. + + Args: + o_en (torch.FloatTensor): Encoder output. + dr (torch.IntTensor): Ground truth durations or alignment network durations. + x_mask (torch.IntTensor): Input sequence mask. + y_lengths (torch.IntTensor): Output sequence lengths. + g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations. + """ + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype) # expand o_en with durations o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask) # positional encoding @@ -286,7 +385,34 @@ class FastPitch(BaseTTS): o_de = self.decoder(o_en_ex, y_mask, g=g) return o_de.transpose(1, 2), attn.transpose(1, 2) - def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None): + def _forward_pitch_predictor( + self, + o_en: torch.FloatTensor, + x_mask: torch.IntTensor, + pitch: torch.FloatTensor = None, + dr: torch.IntTensor = None, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + """Pitch predictor forward pass. + + 1. Predict pitch from encoder outputs. + 2. In training - Compute average pitch values for each input character from the ground truth pitch values. + 3. Embed average pitch values. + + Args: + o_en (torch.FloatTensor): Encoder output. + x_mask (torch.IntTensor): Input sequence mask. + pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None. + dr (torch.IntTensor, optional): Ground truth durations. Defaults to None. + + Returns: + Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction. + + Shapes: + - o_en: :math:`(B, C, T_{en})` + - x_mask: :math:`(B, 1, T_{en})` + - pitch: :math:`(B, 1, T_{de})` + - dr: :math:`(B, T_{en})` + """ o_pitch = self.pitch_predictor(o_en, x_mask) if pitch is not None: avg_pitch = average_pitch(pitch, dr) @@ -295,49 +421,111 @@ class FastPitch(BaseTTS): o_pitch_emb = self.pitch_emb(o_pitch) return o_pitch_emb, o_pitch - def _forward_aligner(self, y, embedding, x_mask, y_mask): + def _forward_aligner( + self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor + ) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + """Aligner forward pass. + + 1. Compute a mask to apply to the attention map. + 2. Run the alignment network. + 3. Apply MAS to compute the hard alignment map. + 4. Compute the durations from the hard alignment map. + + Args: + x (torch.FloatTensor): Input sequence. + y (torch.FloatTensor): Output sequence. + x_mask (torch.IntTensor): Input sequence mask. + y_mask (torch.IntTensor): Output sequence mask. + + Returns: + Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: + Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials, + hard alignment map. + + Shapes: + - x: :math:`[B, T_en, C_en]` + - y: :math:`[B, T_de, C_de]` + - x_mask: :math:`[B, 1, T_en]` + - y_mask: :math:`[B, 1, T_de]` + + - o_alignment_dur: :math:`[B, T_en]` + - alignment_soft: :math:`[B, T_en, T_de]` + - alignment_logprob: :math:`[B, 1, T_de, T_en]` + - alignment_mas: :math:`[B, T_en, T_de]` + """ attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2) - alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None) + alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None) alignment_mas = maximum_path( alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous() ) - o_alignment_dur = torch.sum(alignment_mas, -1) - return o_alignment_dur, alignment_logprob, alignment_mas + o_alignment_dur = torch.sum(alignment_mas, -1).int() + alignment_soft = alignment_soft.squeeze(1).transpose(1, 2) + return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas def forward( - self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None} - ): # pylint: disable=unused-argument - """ + self, + x: torch.LongTensor, + x_lengths: torch.LongTensor, + y_lengths: torch.LongTensor, + y: torch.FloatTensor = None, + dr: torch.IntTensor = None, + pitch: torch.FloatTensor = None, + aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument + ) -> Dict: + """Model's forward pass. + + Args: + x (torch.LongTensor): Input character sequences. + x_lengths (torch.LongTensor): Input sequence lengths. + y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None. + y (torch.FloatTensor): Spectrogram frames. Defaults to None. + dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None. + pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`. + Shapes: - x: :math:`[B, T_max]` - x_lengths: :math:`[B]` - y_lengths: :math:`[B]` - y: :math:`[B, T_max2]` - dr: :math:`[B, T_max]` - g: :math:`[B, C]` - pitch: :math:`[B, 1, T]` + - x: :math:`[B, T_max]` + - x_lengths: :math:`[B]` + - y_lengths: :math:`[B]` + - y: :math:`[B, T_max2]` + - dr: :math:`[B, T_max]` + - g: :math:`[B, C]` + - pitch: :math:`[B, 1, T]` """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None - y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype) - o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g) - if self.config.model_args.detach_duration_predictor: + # compute sequence masks + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype) + # encoder pass + o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g) + # duration predictor pass + if self.args.detach_duration_predictor: o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask) else: o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration) + # generate attn mask from predicted durations + o_attn = self.generate_attn(o_dr.squeeze(1), x_mask) + # aligner pass if self.use_aligner: - o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask) + o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner( + x_emb, y, x_mask, y_mask + ) dr = o_alignment_dur + # pitch predictor pass o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr) o_en = o_en + o_pitch_emb - o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g) + # decoder pass + o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g) outputs = { "model_outputs": o_de, "durations_log": o_dr_log.squeeze(1), "durations": o_dr.squeeze(1), + "attn_durations": o_attn, # for visualization "pitch": o_pitch, "pitch_gt": avg_pitch, "alignments": attn, + "alignment_soft": alignment_soft.transpose(1, 2), "alignment_mas": alignment_mas.transpose(1, 2), "o_alignment_dur": o_alignment_dur, "alignment_logprob": alignment_logprob, @@ -346,43 +534,33 @@ class FastPitch(BaseTTS): @torch.no_grad() def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument - """ + """Model's inference pass. + + Args: + x (torch.LongTensor): Input character sequence. + aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`. + Shapes: - x: [B, T_max] - x_lengths: [B] - g: [B, C] + - x: [B, T_max] + - x_lengths: [B] + - g: [B, C] """ g = aux_input["d_vectors"] if "d_vectors" in aux_input else None x_lengths = torch.tensor(x.shape[1:2]).to(x.device) - # input sequence should be greated than the max convolution size - inference_padding = 5 - if x.shape[1] < 13: - inference_padding += 13 - x.shape[1] - # pad input to prevent dropping the last word - x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0) - o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g) + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float() + # encoder pass + o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g) # duration predictor pass o_dr_log = self.duration_predictor(o_en_dp, x_mask) o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1) + y_lengths = o_dr.sum(1) # pitch predictor pass o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask) - # if pitch_transform is not None: - # if self.pitch_std[0] == 0.0: - # # XXX LJSpeech-1.1 defaults - # mean, std = 218.14, 67.24 - # else: - # mean, std = self.pitch_mean[0], self.pitch_std[0] - # pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std) - - # if pitch_tgt is None: - # pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2) - # else: - # pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2) o_en = o_en + o_pitch_emb - y_lengths = o_dr.sum(1) - o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g) + # decoder pass + o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g) outputs = { - "model_outputs": o_de.transpose(1, 2), + "model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log, @@ -398,33 +576,35 @@ class FastPitch(BaseTTS): d_vectors = batch["d_vectors"] speaker_ids = batch["speaker_ids"] durations = batch["durations"] - aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids} + # forward pass outputs = self.forward( text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input ) - + # use aligner's output as the duration target if self.use_aligner: durations = outputs["o_alignment_dur"] - - with autocast(enabled=False): # use float32 for the criterion + # use float32 in AMP + with autocast(enabled=False): # compute loss loss_dict = criterion( - outputs["model_outputs"], - mel_input, - mel_lengths, - outputs["durations_log"], - durations, - outputs["pitch"], - outputs["pitch_gt"], - text_lengths, - outputs["alignment_logprob"], + decoder_output=outputs["model_outputs"], + decoder_target=mel_input, + decoder_output_lens=mel_lengths, + dur_output=outputs["durations_log"], + dur_target=durations, + pitch_output=outputs["pitch"], + pitch_target=outputs["pitch_gt"], + input_lens=text_lengths, + alignment_logprob=outputs["alignment_logprob"], + alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None, + alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None ) - # compute duration error durations_pred = outputs["durations"] duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum() loss_dict["duration_error"] = duration_error + return outputs, loss_dict def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use @@ -442,9 +622,10 @@ class FastPitch(BaseTTS): "alignment": plot_alignment(align_img, output_fig=False), } - if self.config.model_args.use_aligner and self.training: - alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy() - figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False) + # plot the attention mask computed from the predicted durations + if "attn_durations" in outputs: + alignments_hat = outputs["attn_durations"][0].data.cpu().numpy() + figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False) # Sample audio train_audio = ap.inv_melspectrogram(pred_spec.T) @@ -470,8 +651,20 @@ class FastPitch(BaseTTS): return FastPitchLoss(self.config) + def on_train_step_start(self, trainer): + """Enable binary alignment loss when needed""" + if trainer.total_steps_done > self.config.binary_align_loss_start_step: + self.use_binary_alignment_loss = True + def average_pitch(pitch, durs): + """Compute the average pitch value for each input character based on the durations. + + Shapes: + - pitch: :math:`[B, 1, T_de]` + - durs: :math:`[B, T_en]` + """ + durs_cums_ends = torch.cumsum(durs, dim=1).long() durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0)) pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0)) diff --git a/docs/source/index.md b/docs/source/index.md index 77d198c0..d842f894 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -45,6 +45,7 @@ models/glow_tts.md models/vits.md + models/fast_pitch.md .. toctree:: :maxdepth: 2 diff --git a/recipes/ljspeech/fast_pitch/train_fast_pitch.py b/recipes/ljspeech/fast_pitch/train_fast_pitch.py index 63f50dd9..5c9e67da 100644 --- a/recipes/ljspeech/fast_pitch/train_fast_pitch.py +++ b/recipes/ljspeech/fast_pitch/train_fast_pitch.py @@ -14,10 +14,11 @@ dataset_config = BaseDatasetConfig( # meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"), path=os.path.join(output_path, "../LJSpeech-1.1/"), ) + audio_config = BaseAudioConfig( sample_rate=22050, - do_trim_silence=False, - trim_db=0.0, + do_trim_silence=True, + trim_db=60.0, signal_norm=False, mel_fmin=0.0, mel_fmax=8000, @@ -26,6 +27,7 @@ audio_config = BaseAudioConfig( ref_level_db=20, preemphasis=0.0, ) + config = FastPitchConfig( run_name="fast_pitch_ljspeech", audio=audio_config, @@ -33,6 +35,7 @@ config = FastPitchConfig( eval_batch_size=16, num_loader_workers=8, num_eval_loader_workers=4, + compute_input_seq_cache=True, compute_f0=True, f0_cache_path=os.path.join(output_path, "f0_cache"), run_eval=True, @@ -45,6 +48,8 @@ config = FastPitchConfig( print_step=50, print_eval=False, mixed_precision=False, + sort_by_audio_len=True, + max_seq_len=500000, output_path=output_path, datasets=[dataset_config], )