FastPitch refactor and commenting

pull/792/head
Eren Gölge 2021-09-03 13:26:49 +00:00
parent 59b24e66cf
commit 2bf9e83c49
4 changed files with 403 additions and 204 deletions

View File

@ -756,7 +756,7 @@ class FastPitchLoss(nn.Module):
loss = loss + self.aligner_loss_alpha * aligner_loss
return_dict["loss_aligner"] = self.aligner_loss_alpha * aligner_loss
if self.binary_alignment_loss_alpha > 0:
if self.binary_alignment_loss_alpha > 0 and alignment_hard is not None:
binary_alignment_loss = self._binary_alignment_loss(alignment_hard, alignment_soft)
loss = loss + self.binary_alignment_loss_alpha * binary_alignment_loss
return_dict["loss_binary_alignment"] = self.binary_alignment_loss_alpha * binary_alignment_loss

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Tuple
from typing import Dict, Tuple
import torch
from coqpit import Coqpit
@ -8,6 +8,7 @@ from torch.cuda.amp.autocast_mode import autocast
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.aligner import AlignmentNetwork
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
@ -15,87 +16,101 @@ from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
class AlignmentEncoder(torch.nn.Module):
def __init__(
self,
in_query_channels=80,
in_key_channels=512,
attn_channels=80,
temperature=0.0005,
):
super().__init__()
self.temperature = temperature
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_layer = nn.Sequential(
nn.Conv1d(
in_key_channels,
in_key_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
)
self.query_layer = nn.Sequential(
nn.Conv1d(
in_query_channels,
in_query_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
torch.nn.ReLU(),
nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
)
def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
) -> Tuple[torch.tensor, torch.tensor]:
"""Forward pass of the aligner encoder.
Shapes:
- queries: :math:`[B, C, T_de]`
- keys: :math:`[B, C_emb, T_en]`
- mask: :math:`[B, T_de]`
Output:
attn (torch.tensor): :math:`[B, 1, T_en, T_de]` soft attention mask.
attn_logp (torch.tensor): :math:`[ßB, 1, T_en , T_de]` log probabilities.
"""
key_out = self.key_layer(keys)
query_out = self.query_layer(queries)
attn_factor = (query_out[:, :, :, None] - key_out[:, :, None]) ** 2
attn_factor = -self.temperature * attn_factor.sum(1, keepdim=True)
if attn_prior is not None:
attn_logp = self.log_softmax(attn_factor) + torch.log(attn_prior[:, None] + 1e-8)
if mask is not None:
attn_logp.data.masked_fill_(~mask.bool().unsqueeze(2), -float("inf"))
attn = self.softmax(attn_logp)
return attn, attn_logp
from TTS.utils.soft_dtw import SoftDTW
@dataclass
class FastPitchArgs(Coqpit):
"""Fast Pitch Model arguments.
Args:
num_chars (int):
Number of characters in the vocabulary. Defaults to 100.
out_channels (int):
Number of output channels. Defaults to 80.
hidden_channels (int):
Number of base hidden channels of the model. Defaults to 512.
num_speakers (int):
Number of speakers for the speaker embedding layer. Defaults to 0.
duration_predictor_hidden_channels (int):
Number of hidden channels in the duration predictor. Defaults to 256.
duration_predictor_dropout_p (float):
Dropout rate for the duration predictor. Defaults to 0.1.
duration_predictor_kernel_size (int):
Kernel size of conv layers in the duration predictor. Defaults to 3.
pitch_predictor_hidden_channels (int):
Number of hidden channels in the pitch predictor. Defaults to 256.
pitch_predictor_dropout_p (float):
Dropout rate for the pitch predictor. Defaults to 0.1.
pitch_predictor_kernel_size (int):
Kernel size of conv layers in the pitch predictor. Defaults to 3.
pitch_embedding_kernel_size (int):
Kernel size of the projection layer in the pitch predictor. Defaults to 3.
positional_encoding (bool):
Whether to use positional encoding. Defaults to True.
positional_encoding_use_scale (bool):
Whether to use a learnable scale coeff in the positional encoding. Defaults to True.
length_scale (int):
Length scale that multiplies the predicted durations. Larger values result slower speech. Defaults to 1.0.
encoder_type (str):
Type of the encoder module. One of the encoders available in :class:`TTS.tts.layers.feed_forward.encoder`.
Defaults to `fftransformer` as in the paper.
encoder_params (dict):
Parameters of the encoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
decoder_type (str):
Type of the decoder module. One of the decoders available in :class:`TTS.tts.layers.feed_forward.decoder`.
Defaults to `fftransformer` as in the paper.
decoder_params (str):
Parameters of the decoder module. Defaults to ```{"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}```
use_d_vetor (bool):
Whether to use precomputed d-vectors for multi-speaker training. Defaults to False.
d_vector_dim (int):
Number of channels of the d-vectors. Defaults to 0.
detach_duration_predictor (bool):
Detach the input to the duration predictor from the earlier computation graph so that the duraiton loss
does not pass to the earlier layers. Defaults to True.
max_duration (int):
Maximum duration accepted by the model. Defaults to 75.
use_aligner (bool):
Use aligner network to learn the text to speech alignment. Defaults to True.
"""
num_chars: int = None
out_channels: int = 80
hidden_channels: int = 256
hidden_channels: int = 384
num_speakers: int = 0
duration_predictor_hidden_channels: int = 256
duration_predictor_dropout: float = 0.1
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
pitch_predictor_hidden_channels: int = 256
pitch_predictor_dropout: float = 0.1
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
positional_encoding: bool = True
poisitonal_encoding_use_scale: bool = True
length_scale: int = 1
encoder_type: str = "fftransformer"
encoder_params: dict = field(
@ -109,14 +124,16 @@ class FastPitchArgs(Coqpit):
d_vector_dim: int = 0
detach_duration_predictor: bool = False
max_duration: int = 75
use_gt_duration: bool = True
use_aligner: bool = True
class FastPitch(BaseTTS):
"""FastPitch model. Very similart to SpeedySpeech model but with pitch prediction.
Paper abstract:
Paper::
https://arxiv.org/abs/2006.06873
Paper abstract::
We present FastPitch, a fully-parallel text-to-speech model based on FastSpeech, conditioned on fundamental
frequency contours. The model predicts pitch contours during inference. By altering these predictions,
the generated speech can be more expressive, better match the semantic of the utterance, and in the end
@ -126,9 +143,6 @@ class FastPitch(BaseTTS):
and FastPitch retains the favorable, fully-parallel Transformer architecture, with over 900x real-time
factor for mel-spectrogram synthesis of a typical utterance."
Notes:
TODO
Args:
config (Coqpit): Model coqpit class.
@ -143,95 +157,138 @@ class FastPitch(BaseTTS):
super().__init__()
if "characters" in config:
# loading from FasrPitchConfig
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
args = self.config.model_args
else:
# loading from FastPitchArgs
# don't use isintance not to import recursively
if config.__class__.__name__ == "FastPitchConfig":
if "characters" in config:
# loading from FasrPitchConfig
_, self.config, num_chars = self.get_characters(config)
config.model_args.num_chars = num_chars
self.args = self.config.model_args
else:
# loading from FastPitchArgs
self.config = config
self.args = config.model_args
elif isinstance(config, FastPitchArgs):
self.args = config
self.config = config
args = config
else:
raise ValueError("config must be either a VitsConfig or Vitsself.args")
self.max_duration = args.max_duration
self.use_gt_duration = args.use_gt_duration
self.use_aligner = args.use_aligner
self.max_duration = self.args.max_duration
self.use_aligner = self.args.use_aligner
self.use_binary_alignment_loss = False
self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
self.encoder = Encoder(
config.model_args.hidden_channels,
config.model_args.hidden_channels,
config.model_args.encoder_type,
config.model_args.encoder_params,
config.model_args.d_vector_dim,
self.length_scale = (
float(self.args.length_scale) if isinstance(self.args.length_scale, int) else self.args.length_scale
)
if config.model_args.positional_encoding:
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
self.emb = nn.Embedding(self.args.num_chars, self.args.hidden_channels)
self.encoder = Encoder(
self.args.hidden_channels,
self.args.hidden_channels,
self.args.encoder_type,
self.args.encoder_params,
self.args.d_vector_dim,
)
if self.args.positional_encoding:
self.pos_encoder = PositionalEncoding(self.args.hidden_channels)
self.decoder = Decoder(
config.model_args.out_channels,
config.model_args.hidden_channels,
config.model_args.decoder_type,
config.model_args.decoder_params,
self.args.out_channels,
self.args.hidden_channels,
self.args.decoder_type,
self.args.decoder_params,
)
self.duration_predictor = DurationPredictor(
config.model_args.hidden_channels + config.model_args.d_vector_dim,
config.model_args.duration_predictor_hidden_channels,
config.model_args.duration_predictor_kernel_size,
config.model_args.duration_predictor_dropout_p,
self.args.hidden_channels + self.args.d_vector_dim,
self.args.duration_predictor_hidden_channels,
self.args.duration_predictor_kernel_size,
self.args.duration_predictor_dropout_p,
)
self.pitch_predictor = DurationPredictor(
config.model_args.hidden_channels + config.model_args.d_vector_dim,
config.model_args.pitch_predictor_hidden_channels,
config.model_args.pitch_predictor_kernel_size,
config.model_args.pitch_predictor_dropout_p,
self.args.hidden_channels + self.args.d_vector_dim,
self.args.pitch_predictor_hidden_channels,
self.args.pitch_predictor_kernel_size,
self.args.pitch_predictor_dropout_p,
)
self.pitch_emb = nn.Conv1d(
1,
config.model_args.hidden_channels,
kernel_size=config.model_args.pitch_embedding_kernel_size,
padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2),
self.args.hidden_channels,
kernel_size=self.args.pitch_embedding_kernel_size,
padding=int((self.args.pitch_embedding_kernel_size - 1) / 2),
)
if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
if self.args.num_speakers > 1 and not self.args.use_d_vector:
# speaker embedding layer
self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
self.emb_g = nn.Embedding(self.args.num_speakers, self.args.d_vector_dim)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
if self.args.d_vector_dim > 0 and self.args.d_vector_dim != self.args.hidden_channels:
self.proj_g = nn.Conv1d(self.args.d_vector_dim, self.args.hidden_channels, 1)
if args.use_aligner:
self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
if self.args.use_aligner:
self.aligner = AlignmentNetwork(in_query_channels=self.args.out_channels, in_key_channels=self.args.hidden_channels)
@staticmethod
def expand_encoder_outputs(en, dr, x_mask, y_mask):
def generate_attn(dr, x_mask, y_mask=None):
"""Generate an attention mask from the durations.
Shapes
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
"""
# compute decode mask from the durations
if y_mask is None:
y_lengths = dr.sum(1).long()
y_lengths[y_lengths < 1] = 1
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(dr.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(dr.dtype)
return attn
def expand_encoder_outputs(self, en, dr, x_mask, y_mask):
"""Generate attention alignment map from durations and
expand encoder outputs
Example:
encoder output: [a,b,c,d]
durations: [1, 3, 2, 1]
Shapes
- en: :math:`(B, D_{en}, T_{en})`
- dr: :math:`(B, T_{en})`
- x_mask: :math:`(B, T_{en})`
- y_mask: :math:`(B, T_{de})`
expanded: [a, b, b, b, c, c, d]
attention map: [[0, 0, 0, 0, 0, 0, 1],
[0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 1, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0]]
Examples:
- encoder output: :math:`[a,b,c,d]`
- durations: :math:`[1, 3, 2, 1]`
- expanded: :math:`[a, b, b, b, c, c, d]`
- attention map: :math:`[[0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]`
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
attn = self.generate_attn(dr, x_mask, y_mask)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2).to(en.dtype), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
"""Format predicted durations.
1. Convert to linear scale from log scale
2. Apply the length scale for speed adjustment
3. Apply masking.
4. Cast 0 durations to 1.
5. Round the duration values.
Args:
o_dr_log: Log scale durations.
x_mask: Input text mask.
Shapes:
- o_dr_log: :math:`(B, T_{de})`
- x_mask: :math:`(B, T_{en})`
"""
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
o_dr[o_dr < 1] = 1.0
o_dr = torch.round(o_dr)
@ -249,22 +306,39 @@ class FastPitch(BaseTTS):
g = self.proj_g(g)
return x + g
def _forward_encoder(self, x, x_lengths, g=None):
def _forward_encoder(
self, x: torch.LongTensor, x_mask:torch.FloatTensor, g: torch.FloatTensor = None
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Encoding forward pass.
1. Embed speaker IDs if multi-speaker mode.
2. Embed character sequences.
3. Run the encoder network.
4. Concat speaker embedding to the encoder output for the duration predictor.
Args:
x (torch.LongTensor): Input sequence IDs.
x_mask (torch.FloatTensor): Input squence mask.
g (torch.FloatTensor, optional): Conditioning vectors. In general speaker embeddings. Defaults to None.
Returns:
Tuple[torch.tensor, torch.tensor, torch.tensor, torch.tensor, torch.tensor]:
encoder output, encoder output for the duration predictor, input sequence mask, speaker embeddings,
character embeddings
Shapes:
- x: :math:`(B, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- g: :math:`(B, C)`
"""
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
g = g.unsqueeze(-1)
# [B, T, C]
x_emb = self.emb(x)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
# speaker conditioning for duration predictor
if g is not None:
o_en_dp = self._concat_speaker_embedding(o_en, g)
@ -272,8 +346,33 @@ class FastPitch(BaseTTS):
o_en_dp = o_en
return o_en, o_en_dp, x_mask, g, x_emb
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
def _forward_decoder(
self,
o_en: torch.FloatTensor,
dr: torch.IntTensor,
x_mask: torch.FloatTensor,
y_lengths: torch.IntTensor,
g: torch.FloatTensor,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Decoding forward pass.
1. Compute the decoder output mask
2. Expand encoder output with the durations.
3. Apply position encoding.
4. Add speaker embeddings if multi-speaker mode.
5. Run the decoder.
Args:
o_en (torch.FloatTensor): Encoder output.
dr (torch.IntTensor): Ground truth durations or alignment network durations.
x_mask (torch.IntTensor): Input sequence mask.
y_lengths (torch.IntTensor): Output sequence lengths.
g (torch.FloatTensor): Conditioning vectors. In general speaker embeddings.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Decoder output, attention map from durations.
"""
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
@ -286,7 +385,34 @@ class FastPitch(BaseTTS):
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de.transpose(1, 2), attn.transpose(1, 2)
def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None):
def _forward_pitch_predictor(
self,
o_en: torch.FloatTensor,
x_mask: torch.IntTensor,
pitch: torch.FloatTensor = None,
dr: torch.IntTensor = None,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""Pitch predictor forward pass.
1. Predict pitch from encoder outputs.
2. In training - Compute average pitch values for each input character from the ground truth pitch values.
3. Embed average pitch values.
Args:
o_en (torch.FloatTensor): Encoder output.
x_mask (torch.IntTensor): Input sequence mask.
pitch (torch.FloatTensor, optional): Ground truth pitch values. Defaults to None.
dr (torch.IntTensor, optional): Ground truth durations. Defaults to None.
Returns:
Tuple[torch.FloatTensor, torch.FloatTensor]: Pitch embedding, pitch prediction.
Shapes:
- o_en: :math:`(B, C, T_{en})`
- x_mask: :math:`(B, 1, T_{en})`
- pitch: :math:`(B, 1, T_{de})`
- dr: :math:`(B, T_{en})`
"""
o_pitch = self.pitch_predictor(o_en, x_mask)
if pitch is not None:
avg_pitch = average_pitch(pitch, dr)
@ -295,49 +421,111 @@ class FastPitch(BaseTTS):
o_pitch_emb = self.pitch_emb(o_pitch)
return o_pitch_emb, o_pitch
def _forward_aligner(self, y, embedding, x_mask, y_mask):
def _forward_aligner(
self, x: torch.FloatTensor, y: torch.FloatTensor, x_mask: torch.IntTensor, y_mask: torch.IntTensor
) -> Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
"""Aligner forward pass.
1. Compute a mask to apply to the attention map.
2. Run the alignment network.
3. Apply MAS to compute the hard alignment map.
4. Compute the durations from the hard alignment map.
Args:
x (torch.FloatTensor): Input sequence.
y (torch.FloatTensor): Output sequence.
x_mask (torch.IntTensor): Input sequence mask.
y_mask (torch.IntTensor): Output sequence mask.
Returns:
Tuple[torch.IntTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
Durations from the hard alignment map, soft alignment potentials, log scale alignment potentials,
hard alignment map.
Shapes:
- x: :math:`[B, T_en, C_en]`
- y: :math:`[B, T_de, C_de]`
- x_mask: :math:`[B, 1, T_en]`
- y_mask: :math:`[B, 1, T_de]`
- o_alignment_dur: :math:`[B, T_en]`
- alignment_soft: :math:`[B, T_en, T_de]`
- alignment_logprob: :math:`[B, 1, T_de, T_en]`
- alignment_mas: :math:`[B, T_en, T_de]`
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), x.transpose(1, 2), x_mask, None)
alignment_mas = maximum_path(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
)
o_alignment_dur = torch.sum(alignment_mas, -1)
return o_alignment_dur, alignment_logprob, alignment_mas
o_alignment_dur = torch.sum(alignment_mas, -1).int()
alignment_soft = alignment_soft.squeeze(1).transpose(1, 2)
return o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas
def forward(
self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
): # pylint: disable=unused-argument
"""
self,
x: torch.LongTensor,
x_lengths: torch.LongTensor,
y_lengths: torch.LongTensor,
y: torch.FloatTensor = None,
dr: torch.IntTensor = None,
pitch: torch.FloatTensor = None,
aux_input: Dict = {"d_vectors": 0, "speaker_ids": None}, # pylint: disable=unused-argument
) -> Dict:
"""Model's forward pass.
Args:
x (torch.LongTensor): Input character sequences.
x_lengths (torch.LongTensor): Input sequence lengths.
y_lengths (torch.LongTensor): Output sequnce lengths. Defaults to None.
y (torch.FloatTensor): Spectrogram frames. Defaults to None.
dr (torch.IntTensor): Character durations over the spectrogram frames. Defaults to None.
pitch (torch.FloatTensor): Pitch values for each spectrogram frame. Defaults to None.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": 0, "speaker_ids": None}`.
Shapes:
x: :math:`[B, T_max]`
x_lengths: :math:`[B]`
y_lengths: :math:`[B]`
y: :math:`[B, T_max2]`
dr: :math:`[B, T_max]`
g: :math:`[B, C]`
pitch: :math:`[B, 1, T]`
- x: :math:`[B, T_max]`
- x_lengths: :math:`[B]`
- y_lengths: :math:`[B]`
- y: :math:`[B, T_max2]`
- dr: :math:`[B, T_max]`
- g: :math:`[B, C]`
- pitch: :math:`[B, 1, T]`
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g)
if self.config.model_args.detach_duration_predictor:
# compute sequence masks
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(y.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(y.dtype)
# encoder pass
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_mask, g)
# duration predictor pass
if self.args.detach_duration_predictor:
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
else:
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# generate attn mask from predicted durations
o_attn = self.generate_attn(o_dr.squeeze(1), x_mask)
# aligner pass
if self.use_aligner:
o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask)
o_alignment_dur, alignment_soft, alignment_logprob, alignment_mas = self._forward_aligner(
x_emb, y, x_mask, y_mask
)
dr = o_alignment_dur
# pitch predictor pass
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
# decoder pass
o_de, attn = self._forward_decoder(o_en, dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de,
"durations_log": o_dr_log.squeeze(1),
"durations": o_dr.squeeze(1),
"attn_durations": o_attn, # for visualization
"pitch": o_pitch,
"pitch_gt": avg_pitch,
"alignments": attn,
"alignment_soft": alignment_soft.transpose(1, 2),
"alignment_mas": alignment_mas.transpose(1, 2),
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
@ -346,43 +534,33 @@ class FastPitch(BaseTTS):
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""
"""Model's inference pass.
Args:
x (torch.LongTensor): Input character sequence.
aux_input (Dict): Auxiliary model inputs. Defaults to `{"d_vectors": None, "speaker_ids": None}`.
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
- x: [B, T_max]
- x_lengths: [B]
- g: [B, C]
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greated than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype).float()
# encoder pass
o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_mask, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
# pitch predictor pass
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
# if pitch_transform is not None:
# if self.pitch_std[0] == 0.0:
# # XXX LJSpeech-1.1 defaults
# mean, std = 218.14, 67.24
# else:
# mean, std = self.pitch_mean[0], self.pitch_std[0]
# pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
# if pitch_tgt is None:
# pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
# else:
# pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
o_en = o_en + o_pitch_emb
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
# decoder pass
o_de, attn = self._forward_decoder(o_en, o_dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de.transpose(1, 2),
"model_outputs": o_de,
"alignments": attn,
"pitch": o_pitch,
"durations_log": o_dr_log,
@ -398,33 +576,35 @@ class FastPitch(BaseTTS):
d_vectors = batch["d_vectors"]
speaker_ids = batch["speaker_ids"]
durations = batch["durations"]
aux_input = {"d_vectors": d_vectors, "speaker_ids": speaker_ids}
# forward pass
outputs = self.forward(
text_input, text_lengths, mel_lengths, y=mel_input, dr=durations, pitch=pitch, aux_input=aux_input
)
# use aligner's output as the duration target
if self.use_aligner:
durations = outputs["o_alignment_dur"]
with autocast(enabled=False): # use float32 for the criterion
# use float32 in AMP
with autocast(enabled=False):
# compute loss
loss_dict = criterion(
outputs["model_outputs"],
mel_input,
mel_lengths,
outputs["durations_log"],
durations,
outputs["pitch"],
outputs["pitch_gt"],
text_lengths,
outputs["alignment_logprob"],
decoder_output=outputs["model_outputs"],
decoder_target=mel_input,
decoder_output_lens=mel_lengths,
dur_output=outputs["durations_log"],
dur_target=durations,
pitch_output=outputs["pitch"],
pitch_target=outputs["pitch_gt"],
input_lens=text_lengths,
alignment_logprob=outputs["alignment_logprob"],
alignment_soft=outputs["alignment_soft"] if self.use_binary_alignment_loss else None,
alignment_hard=outputs["alignment_mas"] if self.use_binary_alignment_loss else None
)
# compute duration error
durations_pred = outputs["durations"]
duration_error = torch.abs(durations - durations_pred).sum() / text_lengths.sum()
loss_dict["duration_error"] = duration_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict): # pylint: disable=no-self-use
@ -442,9 +622,10 @@ class FastPitch(BaseTTS):
"alignment": plot_alignment(align_img, output_fig=False),
}
if self.config.model_args.use_aligner and self.training:
alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy()
figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False)
# plot the attention mask computed from the predicted durations
if "attn_durations" in outputs:
alignments_hat = outputs["attn_durations"][0].data.cpu().numpy()
figures["alignment_hat"] = plot_alignment(alignments_hat.T, output_fig=False)
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
@ -470,8 +651,20 @@ class FastPitch(BaseTTS):
return FastPitchLoss(self.config)
def on_train_step_start(self, trainer):
"""Enable binary alignment loss when needed"""
if trainer.total_steps_done > self.config.binary_align_loss_start_step:
self.use_binary_alignment_loss = True
def average_pitch(pitch, durs):
"""Compute the average pitch value for each input character based on the durations.
Shapes:
- pitch: :math:`[B, 1, T_de]`
- durs: :math:`[B, T_en]`
"""
durs_cums_ends = torch.cumsum(durs, dim=1).long()
durs_cums_starts = torch.nn.functional.pad(durs_cums_ends[:, :-1], (1, 0))
pitch_nonzero_cums = torch.nn.functional.pad(torch.cumsum(pitch != 0.0, dim=2), (1, 0))

View File

@ -45,6 +45,7 @@
models/glow_tts.md
models/vits.md
models/fast_pitch.md
.. toctree::
:maxdepth: 2

View File

@ -14,10 +14,11 @@ dataset_config = BaseDatasetConfig(
# meta_file_attn_mask=os.path.join(output_path, "../LJSpeech-1.1/metadata_attn_mask.txt"),
path=os.path.join(output_path, "../LJSpeech-1.1/"),
)
audio_config = BaseAudioConfig(
sample_rate=22050,
do_trim_silence=False,
trim_db=0.0,
do_trim_silence=True,
trim_db=60.0,
signal_norm=False,
mel_fmin=0.0,
mel_fmax=8000,
@ -26,6 +27,7 @@ audio_config = BaseAudioConfig(
ref_level_db=20,
preemphasis=0.0,
)
config = FastPitchConfig(
run_name="fast_pitch_ljspeech",
audio=audio_config,
@ -33,6 +35,7 @@ config = FastPitchConfig(
eval_batch_size=16,
num_loader_workers=8,
num_eval_loader_workers=4,
compute_input_seq_cache=True,
compute_f0=True,
f0_cache_path=os.path.join(output_path, "f0_cache"),
run_eval=True,
@ -45,6 +48,8 @@ config = FastPitchConfig(
print_step=50,
print_eval=False,
mixed_precision=False,
sort_by_audio_len=True,
max_seq_len=500000,
output_path=output_path,
datasets=[dataset_config],
)