Refactor FastPitchv2

pull/792/head
Eren Gölge 2021-07-26 22:50:34 +02:00
parent e429afbce4
commit 98a7271ce8
2 changed files with 210 additions and 496 deletions


@@ -1,23 +1,23 @@
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from coqpit import Coqpit
from matplotlib.pyplot import plot
from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.duration_predictor import DurationPredictor
from TTS.tts.layers.glow_tts.monotonic_align import generate_path, maximum_path
from TTS.tts.models.base_tts import BaseTTS
from TTS.tts.utils.data import sequence_mask
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
# pylint: disable=dangerous-default-value
class AlignmentEncoder(torch.nn.Module):
"""Module for alignment text and mel spectrogram."""
def __init__(
self,
in_query_channels=80,
@@ -31,32 +31,35 @@ class AlignmentEncoder(torch.nn.Module):
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.key_proj = nn.Sequential(
ConvNorm(
in_key_channels, in_key_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
nn.Conv1d(
in_key_channels,
in_key_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
ConvNorm(in_key_channels * 2, attn_channels, kernel_size=1, bias=True, batch_norm=False),
nn.Conv1d(in_key_channels * 2, attn_channels, kernel_size=1, padding=0, bias=True),
)
self.query_proj = nn.Sequential(
ConvNorm(
in_query_channels, in_query_channels * 2, kernel_size=3, bias=True, w_init_gain="relu", batch_norm=False
nn.Conv1d(
in_query_channels,
in_query_channels * 2,
kernel_size=3,
padding=1,
bias=True,
),
torch.nn.ReLU(),
ConvNorm(in_query_channels * 2, in_query_channels, kernel_size=1, bias=True, batch_norm=False),
nn.Conv1d(in_query_channels * 2, in_query_channels, kernel_size=1, padding=0, bias=True),
torch.nn.ReLU(),
ConvNorm(in_query_channels, attn_channels, kernel_size=1, bias=True, batch_norm=False),
nn.Conv1d(in_query_channels, attn_channels, kernel_size=1, padding=0, bias=True),
)
def forward(
self, queries: torch.tensor, keys: torch.tensor, mask: torch.tensor = None, attn_prior: torch.tensor = None
):
"""Forward pass of the aligner encoder.
Args:
queries (torch.tensor): query tensor.
keys (torch.tensor): key tensor.
mask (torch.tensor): uint8 binary mask for variable length entries (should be in the T2 domain).
attn_prior (torch.tensor): prior for attention matrix.
Shapes:
- queries: :math:`(B, C, T_de)`
- keys: :math:`(B, C_emb, T_en)`
@@ -84,365 +87,30 @@ class AlignmentEncoder(torch.nn.Module):
return attn, attn_logprob
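# Note: `attn` is the soft text-to-mel alignment (a distribution over input
# tokens for every decoder frame) and `attn_logprob` holds the log-scores
# consumed by the aligner's forward-sum loss.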
def mask_from_lens(lens, max_len: int = None):
if max_len is None:
max_len = lens.max()
ids = torch.arange(0, max_len, device=lens.device, dtype=lens.dtype)
mask = torch.lt(ids, lens.unsqueeze(1))
return mask
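# Illustrative example (not part of the commit): mask_from_lens builds a
# [B, T_max] boolean mask that is True at valid (non-padded) positions.
#
#     >>> mask_from_lens(torch.tensor([3, 1]))
#     tensor([[ True,  True,  True],
#             [ True, False, False]])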
class LinearNorm(torch.nn.Module):
def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
super(LinearNorm, self).__init__()
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, x):
return self.linear_layer(x)
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain="linear",
batch_norm=False,
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
self.norm = torch.nn.BatchNorm1d(out_channels) if batch_norm else None
torch.nn.init.xavier_uniform_(self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
if self.norm is None:
return self.conv(signal)
else:
return self.norm(self.conv(signal))
class ConvReLUNorm(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, dropout=0.0):
super(ConvReLUNorm, self).__init__()
self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size // 2))
self.norm = torch.nn.LayerNorm(out_channels)
self.dropout = torch.nn.Dropout(dropout)
def forward(self, signal):
out = F.relu(self.conv(signal))
out = self.norm(out.transpose(1, 2)).transpose(1, 2).to(signal.dtype)
return self.dropout(out)
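# LayerNorm normalizes over the last dimension, so the [B, C, T] conv output
# is transposed to [B, T, C] for normalization and transposed back afterwards.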
class PositionalEmbedding(nn.Module):
def __init__(self, demb):
super(PositionalEmbedding, self).__init__()
self.demb = demb
inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
self.register_buffer("inv_freq", inv_freq)
def forward(self, pos_seq, bsz=None):
sinusoid_inp = torch.matmul(torch.unsqueeze(pos_seq, -1), torch.unsqueeze(self.inv_freq, 0))
pos_emb = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=1)
if bsz is not None:
return pos_emb[None, :, :].expand(bsz, -1, -1)
else:
return pos_emb[None, :, :]
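# Layout note: unlike the interleaved sin/cos of the original Transformer,
# this concatenation puts all sine channels in the first half of the feature
# axis and all cosine channels in the second half.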
class PositionwiseConvFF(nn.Module):
def __init__(self, d_model, d_inner, kernel_size, dropout, pre_lnorm=False):
super(PositionwiseConvFF, self).__init__()
self.d_model = d_model
self.d_inner = d_inner
self.dropout = dropout
self.CoreNet = nn.Sequential(
nn.Conv1d(d_model, d_inner, kernel_size, 1, (kernel_size // 2)),
nn.ReLU(),
# nn.Dropout(dropout), # worse convergence
nn.Conv1d(d_inner, d_model, kernel_size, 1, (kernel_size // 2)),
nn.Dropout(dropout),
)
self.layer_norm = nn.LayerNorm(d_model)
self.pre_lnorm = pre_lnorm
def forward(self, inp):
return self._forward(inp)
def _forward(self, inp):
if self.pre_lnorm:
# layer normalization + positionwise feed-forward
core_out = inp.transpose(1, 2)
core_out = self.CoreNet(self.layer_norm(core_out).to(inp.dtype))
core_out = core_out.transpose(1, 2)
# residual connection
output = core_out + inp
else:
# positionwise feed-forward
core_out = inp.transpose(1, 2)
core_out = self.CoreNet(core_out)
core_out = core_out.transpose(1, 2)
# residual connection + layer normalization
output = self.layer_norm(inp + core_out).to(inp.dtype)
return output
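# pre_lnorm=True normalizes before the conv stack (pre-norm residual, usually
# more stable in deep stacks); pre_lnorm=False keeps the original post-norm
# Transformer ordering, LayerNorm(x + FF(x)).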
class MultiHeadAttn(nn.Module):
def __init__(self, num_heads, d_model, hidden_channels_head, dropout, dropout_attn=0.1, pre_lnorm=False):
super(MultiHeadAttn, self).__init__()
self.num_heads = num_heads
self.d_model = d_model
self.hidden_channels_head = hidden_channels_head
self.scale = 1 / (hidden_channels_head ** 0.5)
self.pre_lnorm = pre_lnorm
self.qkv_net = nn.Linear(d_model, 3 * num_heads * hidden_channels_head)
self.drop = nn.Dropout(dropout)
self.dropout_attn = nn.Dropout(dropout_attn)
self.o_net = nn.Linear(num_heads * hidden_channels_head, d_model, bias=False)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, inp, attn_mask=None):
return self._forward(inp, attn_mask)
def _forward(self, inp, attn_mask=None):
residual = inp
if self.pre_lnorm:
# layer normalization
inp = self.layer_norm(inp)
num_heads, hidden_channels_head = self.num_heads, self.hidden_channels_head
head_q, head_k, head_v = torch.chunk(self.qkv_net(inp), 3, dim=2)
head_q = head_q.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
head_k = head_k.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
head_v = head_v.view(inp.size(0), inp.size(1), num_heads, hidden_channels_head)
q = head_q.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
k = head_k.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
v = head_v.permute(2, 0, 1, 3).reshape(-1, inp.size(1), hidden_channels_head)
attn_score = torch.bmm(q, k.transpose(1, 2))
attn_score.mul_(self.scale)
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(1).to(attn_score.dtype)
attn_mask = attn_mask.repeat(num_heads, attn_mask.size(2), 1)
attn_score.masked_fill_(attn_mask.to(torch.bool), -float("inf"))
attn_prob = F.softmax(attn_score, dim=2)
attn_prob = self.dropout_attn(attn_prob)
attn_vec = torch.bmm(attn_prob, v)
attn_vec = attn_vec.view(num_heads, inp.size(0), inp.size(1), hidden_channels_head)
attn_vec = (
attn_vec.permute(1, 2, 0, 3).contiguous().view(inp.size(0), inp.size(1), num_heads * hidden_channels_head)
)
# linear projection
attn_out = self.o_net(attn_vec)
attn_out = self.drop(attn_out)
if self.pre_lnorm:
# residual connection
output = residual + attn_out
else:
# residual connection + layer normalization
output = self.layer_norm(residual + attn_out)
output = output.to(attn_out.dtype)
return output
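# Shape sketch (illustrative): with inp [B, T, d_model], qkv_net yields
# [B, T, 3 * num_heads * d_head]; q/k/v are reshaped head-major to
# [num_heads * B, T, d_head], attn_score is [num_heads * B, T, T], and the
# heads are folded back to [B, T, num_heads * d_head] before o_net.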
class TransformerLayer(nn.Module):
def __init__(
self, num_heads, hidden_channels, hidden_channels_head, hidden_channels_ffn, kernel_size, dropout, **kwargs
):
super(TransformerLayer, self).__init__()
self.dec_attn = MultiHeadAttn(num_heads, hidden_channels, hidden_channels_head, dropout, **kwargs)
self.pos_ff = PositionwiseConvFF(
hidden_channels, hidden_channels_ffn, kernel_size, dropout, pre_lnorm=kwargs.get("pre_lnorm")
)
def forward(self, dec_inp, mask=None):
output = self.dec_attn(dec_inp, attn_mask=~mask.squeeze(2))
output *= mask
output = self.pos_ff(output)
output *= mask
return output
class FFTransformer(nn.Module):
def __init__(
self,
num_layers,
num_heads,
hidden_channels,
hidden_channels_head,
hidden_channels_ffn,
kernel_size,
dropout,
dropout_attn,
dropemb=0.0,
pre_lnorm=False,
):
super(FFTransformer, self).__init__()
self.hidden_channels = hidden_channels
self.num_heads = num_heads
self.hidden_channels_head = hidden_channels_head
self.pos_emb = PositionalEmbedding(self.hidden_channels)
self.drop = nn.Dropout(dropemb)
self.layers = nn.ModuleList()
for _ in range(num_layers):
self.layers.append(
TransformerLayer(
num_heads,
hidden_channels,
hidden_channels_head,
hidden_channels_ffn,
kernel_size,
dropout,
dropout_attn=dropout_attn,
pre_lnorm=pre_lnorm,
)
)
def forward(self, x, x_lengths, conditioning=0):
mask = mask_from_lens(x_lengths).unsqueeze(2)
pos_seq = torch.arange(x.size(1), device=x.device).to(x.dtype)
pos_emb = self.pos_emb(pos_seq) * mask
if conditioning is None:
conditioning = 0
out = self.drop(x + pos_emb + conditioning)
for layer in self.layers:
out = layer(out, mask=mask)
# out = self.drop(out)
return out, mask
def regulate_len(durations, enc_out, pace=1.0, mel_max_len=None):
"""If target=None, then predicted durations are applied"""
dtype = enc_out.dtype
reps = durations.float() / pace
reps = (reps + 0.5).long()
dec_lens = reps.sum(dim=1)
max_len = dec_lens.max()
reps_cumsum = torch.cumsum(F.pad(reps, (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
reps_cumsum = reps_cumsum.to(dtype)
range_ = torch.arange(max_len).to(enc_out.device)[None, :, None]
mult = (reps_cumsum[:, :, :-1] <= range_) & (reps_cumsum[:, :, 1:] > range_)
mult = mult.to(dtype)
en_ex = torch.matmul(mult, enc_out)
if mel_max_len:
en_ex = en_ex[:, :mel_max_len]
dec_lens = torch.clamp_max(dec_lens, mel_max_len)
return en_ex, dec_lens
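# Worked example (illustrative): with durations [[2, 1]] and pace 1.0, the
# first encoder frame is repeated twice and the second once, so enc_out of
# shape [1, 2, C] becomes en_ex of shape [1, 3, C] with dec_lens == tensor([3]).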
class TemporalPredictor(nn.Module):
"""Predicts a single float per each temporal location"""
def __init__(self, input_size, filter_size, kernel_size, dropout, num_layers=2):
super(TemporalPredictor, self).__init__()
self.layers = nn.Sequential(
*[
ConvReLUNorm(
input_size if i == 0 else filter_size, filter_size, kernel_size=kernel_size, dropout=dropout
)
for i in range(num_layers)
]
)
self.fc = nn.Linear(filter_size, 1, bias=True)
def forward(self, enc_out, enc_out_mask):
out = enc_out * enc_out_mask
out = self.layers(out.transpose(1, 2)).transpose(1, 2)
out = self.fc(out) * enc_out_mask
return out.squeeze(-1)
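# Usage sketch (values hypothetical): with encoder output [B, T, C] and a
# broadcastable mask [B, T, 1], the predictor returns one value per position.
#
#     predictor = TemporalPredictor(384, filter_size=256, kernel_size=3, dropout=0.1)
#     log_durations = predictor(enc_out, enc_out_mask)  # [B, T]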
@dataclass
class FastPitchArgs(Coqpit):
num_chars: int = 100
num_chars: int = None
out_channels: int = 80
hidden_channels: int = 384
hidden_channels: int = 256
num_speakers: int = 0
duration_predictor_hidden_channels: int = 256
duration_predictor_dropout: float = 0.1
duration_predictor_kernel_size: int = 3
duration_predictor_dropout_p: float = 0.1
duration_predictor_num_layers: int = 2
pitch_predictor_hidden_channels: int = 256
pitch_predictor_dropout: float = 0.1
pitch_predictor_kernel_size: int = 3
pitch_predictor_dropout_p: float = 0.1
pitch_embedding_kernel_size: int = 3
pitch_predictor_num_layers: int = 2
positional_encoding: bool = True
length_scale: int = 1
encoder_type: str = "fftransformer"
encoder_params: dict = field(
default_factory=lambda: {
"hidden_channels_head": 64,
"hidden_channels_ffn": 1536,
"num_heads": 1,
"num_layers": 6,
"kernel_size": 3,
"dropout": 0.1,
"dropout_attn": 0.1,
}
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
)
decoder_type: str = "fftransformer"
decoder_params: dict = field(
default_factory=lambda: {
"hidden_channels_head": 64,
"hidden_channels_ffn": 1536,
"num_heads": 1,
"num_layers": 6,
"kernel_size": 3,
"dropout": 0.1,
"dropout_attn": 0.1,
}
default_factory=lambda: {"hidden_channels_ffn": 1024, "num_heads": 1, "num_layers": 6, "dropout_p": 0.1}
)
use_d_vector: bool = False
d_vector_dim: int = 0
@@ -477,7 +145,9 @@ class FastPitch(BaseTTS):
>>> model = FastPitch(config)
"""
# pylint: disable=dangerous-default-value
def __init__(self, config: Coqpit):
super().__init__()
if "characters" in config:
@@ -496,44 +166,54 @@ class FastPitch(BaseTTS):
self.length_scale = float(args.length_scale) if isinstance(args.length_scale, int) else args.length_scale
self.encoder = FFTransformer(
hidden_channels=args.hidden_channels,
**args.encoder_params,
self.emb = nn.Embedding(config.model_args.num_chars, config.model_args.hidden_channels)
self.encoder = Encoder(
config.model_args.hidden_channels,
config.model_args.hidden_channels,
config.model_args.encoder_type,
config.model_args.encoder_params,
config.model_args.d_vector_dim,
)
# if n_speakers > 1:
# self.speaker_emb = nn.Embedding(n_speakers, symbols_embedding_dim)
# else:
# self.speaker_emb = None
# self.speaker_emb_weight = speaker_emb_weight
self.emb = nn.Embedding(args.num_chars, args.hidden_channels)
if config.model_args.positional_encoding:
self.pos_encoder = PositionalEncoding(config.model_args.hidden_channels)
self.duration_predictor = TemporalPredictor(
args.hidden_channels,
filter_size=args.duration_predictor_hidden_channels,
kernel_size=args.duration_predictor_kernel_size,
dropout=args.duration_predictor_dropout_p,
num_layers=args.duration_predictor_num_layers,
self.decoder = Decoder(
config.model_args.out_channels,
config.model_args.hidden_channels,
config.model_args.decoder_type,
config.model_args.decoder_params,
)
self.decoder = FFTransformer(hidden_channels=args.hidden_channels, **args.decoder_params)
self.duration_predictor = DurationPredictor(
config.model_args.hidden_channels + config.model_args.d_vector_dim,
config.model_args.duration_predictor_hidden_channels,
config.model_args.duration_predictor_kernel_size,
config.model_args.duration_predictor_dropout_p,
)
self.pitch_predictor = TemporalPredictor(
args.hidden_channels,
filter_size=args.pitch_predictor_hidden_channels,
kernel_size=args.pitch_predictor_kernel_size,
dropout=args.pitch_predictor_dropout_p,
num_layers=args.pitch_predictor_num_layers,
self.pitch_predictor = DurationPredictor(
config.model_args.hidden_channels + config.model_args.d_vector_dim,
config.model_args.pitch_predictor_hidden_channels,
config.model_args.pitch_predictor_kernel_size,
config.model_args.pitch_predictor_dropout_p,
)
self.pitch_emb = nn.Conv1d(
1,
args.hidden_channels,
kernel_size=args.pitch_embedding_kernel_size,
padding=int((args.pitch_embedding_kernel_size - 1) / 2),
config.model_args.hidden_channels,
kernel_size=config.model_args.pitch_embedding_kernel_size,
padding=int((config.model_args.pitch_embedding_kernel_size - 1) / 2),
)
self.proj = nn.Linear(args.hidden_channels, args.out_channels, bias=True)
if config.model_args.num_speakers > 1 and not config.model_args.use_d_vector:
# speaker embedding layer
self.emb_g = nn.Embedding(config.model_args.num_speakers, config.model_args.d_vector_dim)
nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)
if config.model_args.d_vector_dim > 0 and config.model_args.d_vector_dim != config.model_args.hidden_channels:
self.proj_g = nn.Conv1d(config.model_args.d_vector_dim, config.model_args.hidden_channels, 1)
if args.use_aligner:
self.aligner = AlignmentEncoder(args.out_channels, args.hidden_channels)
@@ -555,64 +235,109 @@ class FastPitch(BaseTTS):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(attn.transpose(1, 2), en)
return o_en_ex, attn.transpose(1, 2)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
return o_en_ex, attn
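# generate_path converts the per-token durations `dr` into a hard monotonic
# alignment map, and the matmul repeats each encoder frame dr[i] times to
# produce a decoder-rate sequence.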
def format_durations(self, o_dr_log, x_mask):
o_dr = (torch.exp(o_dr_log) - 1) * x_mask * self.length_scale
o_dr[o_dr < 1] = 1.0
o_dr = torch.round(o_dr)
return o_dr
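# Illustrative: a predicted log-duration of 0 maps to exp(0) - 1 = 0 frames,
# which is clamped up to 1 so every token gets at least one frame; setting
# length_scale above 1 uniformly slows down the synthesized speech.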
@staticmethod
def _concat_speaker_embedding(o_en, g):
g_exp = g.expand(-1, -1, o_en.size(-1)) # [B, C, T_en]
o_en = torch.cat([o_en, g_exp], 1)
return o_en
def _sum_speaker_embedding(self, x, g):
# project g to decoder dim.
if hasattr(self, "proj_g"):
g = self.proj_g(g)
return x + g
def _forward_encoder(self, x, x_lengths, g=None):
if hasattr(self, "emb_g"):
g = nn.functional.normalize(self.emb_g(g)) # [B, C, 1]
if g is not None:
g = g.unsqueeze(-1)
# [B, T, C]
x_emb = self.emb(x)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
# encoder pass
o_en = self.encoder(torch.transpose(x_emb, 1, -1), x_mask)
# speaker conditioning for duration predictor
if g is not None:
o_en_dp = self._concat_speaker_embedding(o_en, g)
else:
o_en_dp = o_en
return o_en, o_en_dp, x_mask, g, x_emb
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
if hasattr(self, "pos_encoder"):
o_en_ex = self.pos_encoder(o_en_ex, y_mask)
# speaker embedding
if g is not None:
o_en_ex = self._sum_speaker_embedding(o_en_ex, g)
# decoder pass
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de.transpose(1, 2), attn.transpose(1, 2)
def _forward_pitch_predictor(self, o_en, x_mask, pitch=None, dr=None):
o_pitch = self.pitch_predictor(o_en, x_mask)
if pitch is not None:
avg_pitch = average_pitch(pitch, dr)
o_pitch_emb = self.pitch_emb(avg_pitch)
return o_pitch_emb, o_pitch, avg_pitch
o_pitch_emb = self.pitch_emb(o_pitch)
return o_pitch_emb, o_pitch
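# Note: average_pitch (referenced here, defined outside this diff hunk) pools
# the frame-level pitch contour over each token's duration, giving one target
# pitch value per input character, as in the FastPitch paper.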
def _forward_aligner(self, y, embedding, x_mask, y_mask):
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
alignment_mas = maximum_path(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
)
o_alignment_dur = torch.sum(alignment_mas, -1)
return o_alignment_dur, alignment_logprob, alignment_mas
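# maximum_path runs monotonic alignment search over the soft attention map;
# summing the resulting binary alignment over the decoder axis yields an
# integer duration per input character (o_alignment_dur).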
def forward(
self, x, x_lengths, y_lengths, y=None, dr=None, pitch=None, aux_input={"d_vectors": 0, "speaker_ids": None}
):
speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
): # pylint: disable=unused-argument
"""
Shapes:
x: :math:`[B, T_max]`
x_lengths: :math:`[B]`
y_lengths: :math:`[B]`
y: :math:`[B, T_max2]`
dr: :math:`[B, T_max]`
g: :math:`[B, C]`
pitch: :math:`[B, 1, T]`
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
o_alignment_dur = None
alignment_logprob = None
alignment_mas = None
# Calculate speaker embedding
# if self.speaker_emb is None:
# speaker_embedding = 0
# else:
# speaker_embedding = self.speaker_emb(speaker).unsqueeze(1)
# speaker_embedding.mul_(self.speaker_emb_weight)
# character embedding
embedding = self.emb(x)
# Input FFT
o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)
# Embedded for predictors
o_en_dr, mask_en_dr = o_en, mask_en
# Predict durations
o_dr_log = self.duration_predictor(o_en_dr, mask_en_dr)
o_en, o_en_dp, x_mask, g, x_emb = self._forward_encoder(x, x_lengths, g)
if self.config.model_args.detach_duration_predictor:
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
else:
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
# Aligner
if self.use_aligner:
alignment_soft, alignment_logprob = self.aligner(y.transpose(1, 2), embedding.transpose(1, 2), x_mask, None)
alignment_mas = maximum_path(
alignment_soft.squeeze(1).transpose(1, 2).contiguous(), attn_mask.squeeze(1).contiguous()
)
o_alignment_dur = torch.log(1 + torch.sum(alignment_mas, -1))
avg_pitch = average_pitch(pitch, o_alignment_dur)
o_alignment_dur, alignment_logprob, alignment_mas = self._forward_aligner(y, x_emb, x_mask, y_mask)
dr = o_alignment_dur
# TODO: move this to the dataset
avg_pitch = average_pitch(pitch, dr)
# Predict pitch
o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
pitch_emb = self.pitch_emb(avg_pitch)
o_en = o_en + pitch_emb.transpose(1, 2)
# len_regulated, dec_lens = regulate_len(dr, o_en, self.length_scale, mel_max_len)
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# Output FFT
o_de, _ = self.decoder(o_en_ex, y_lengths)
o_de = self.proj(o_de)
o_pitch_emb, o_pitch, avg_pitch = self._forward_pitch_predictor(o_en_dp, x_mask, pitch, dr)
o_en = o_en + o_pitch_emb
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de,
"durations_log": o_dr_log.squeeze(1),
@@ -620,66 +345,55 @@ class FastPitch(BaseTTS):
"pitch": o_pitch,
"pitch_gt": avg_pitch,
"alignments": attn,
"alignment_mas": alignment_mas,
"alignment_mas": alignment_mas.transpose(1, 2),
"o_alignment_dur": o_alignment_dur,
"alignment_logprob": alignment_logprob,
}
return outputs
@torch.no_grad()
def inference(self, x, aux_input={"d_vectors": 0, "speaker_ids": None}): # pylint: disable=unused-argument
speaker_embedding = aux_input["d_vectors"] if "d_vectors" in aux_input else 0
def inference(self, x, aux_input={"d_vectors": None, "speaker_ids": None}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = aux_input["d_vectors"] if "d_vectors" in aux_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# the input sequence should be longer than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# character embedding
embedding = self.emb(x)
# if self.speaker_emb is None:
# else:
# speaker = torch.ones(inputs.size(0)).long().to(inputs.device) * speaker
# spk_emb = self.speaker_emb(speaker).unsqueeze(1)
# spk_emb.mul_(self.speaker_emb_weight)
# Input FFT
o_en, mask_en = self.encoder(embedding, x_lengths, conditioning=speaker_embedding)
# Predict durations
o_dr_log = self.duration_predictor(o_en, mask_en)
o_dr = torch.clamp(torch.exp(o_dr_log) - 1, 0, self.max_duration)
o_dr = o_dr * self.length_scale
# Pitch over chars
o_pitch = self.pitch_predictor(o_en, mask_en).unsqueeze(1)
o_en, o_en_dp, x_mask, g, _ = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp, x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
# pitch predictor pass
o_pitch_emb, o_pitch = self._forward_pitch_predictor(o_en_dp, x_mask)
# if pitch_transform is not None:
# if self.pitch_std[0] == 0.0:
# # XXX LJSpeech-1.1 defaults
# mean, std = 218.14, 67.24
# else:
# mean, std = self.pitch_mean[0], self.pitch_std[0]
# pitch_pred = pitch_transform(pitch_pred, mask_en.sum(dim=(1, 2)), mean, std)
o_pitch_emb = self.pitch_emb(o_pitch).transpose(1, 2)
# pitch_pred = pitch_transform(pitch_pred, enc_mask.sum(dim=(1,2)), mean, std)
# if pitch_tgt is None:
# pitch_emb = self.pitch_emb(pitch_pred.unsqueeze(1)).transpose(1, 2)
# else:
# pitch_emb = self.pitch_emb(pitch_tgt.unsqueeze(1)).transpose(1, 2)
o_en = o_en + o_pitch_emb
y_lengths = o_dr.sum(1)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x.dtype)
o_en_ex, attn = self.expand_encoder_outputs(o_en, o_dr, x_mask, y_mask)
o_de, _ = self.decoder(o_en_ex, y_lengths)
o_de = self.proj(o_de)
outputs = {"model_outputs": o_de, "alignments": attn, "pitch": o_pitch, "durations_log": o_dr_log}
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
outputs = {
"model_outputs": o_de.transpose(1, 2),
"alignments": attn,
"pitch": o_pitch,
"durations_log": o_dr_log,
}
return outputs
def train_step(self, batch: dict, criterion: nn.Module):
@@ -735,8 +449,8 @@ class FastPitch(BaseTTS):
}
if self.config.model_args.use_aligner and self.training:
alignment_mas = outputs["alignment_mas"]
figures["alignment_mas"] = plot_alignment(alignment_mas, ap, output_fig=False)
alignment_mas = outputs["alignment_mas"][0].data.cpu().numpy()
figures["alignment_mas"] = plot_alignment(alignment_mas, output_fig=False)
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)


@@ -647,29 +647,29 @@ class AudioProcessor(object):
# frame_period=1000 * self.hop_length / self.sample_rate,
# )
# f0 = pw.stonemask(x.astype(np.double), f0, t, self.sample_rate)
f0, _, _, _ = compute_yin(
x,
self.sample_rate,
self.win_length,
self.hop_length,
65 if self.mel_fmin == 0 else self.mel_fmin,
self.mel_fmax,
)
# import pyworld as pw
# f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
# frame_period=self.hop_length / self.sample_rate * 1000)
pad = int((self.win_length / self.hop_length) / 2)
f0 = [0.0] * pad + f0 + [0.0] * pad
f0 = np.array(f0, dtype=np.float32)
# f01, _, _ = librosa.pyin(
# f0, _, _, _ = compute_yin(
# x,
# fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
# fmax=self.mel_fmax,
# frame_length=self.win_length,
# sr=self.sample_rate,
# fill_na=0.0,
# self.sample_rate,
# self.win_length,
# self.hop_length,
# 65 if self.mel_fmin == 0 else self.mel_fmin,
# self.mel_fmax,
# )
# # import pyworld as pw
# # f0, _ = pw.dio(x.astype(np.float64), self.sample_rate,
# # frame_period=self.hop_length / self.sample_rate * 1000)
# pad = int((self.win_length / self.hop_length) / 2)
# f0 = [0.0] * pad + f0 + [0.0] * pad
# f0 = np.array(f0, dtype=np.float32)
f0, _, _ = librosa.pyin(
x,
fmin=65 if self.mel_fmin == 0 else self.mel_fmin,
fmax=self.mel_fmax,
frame_length=self.win_length,
sr=self.sample_rate,
fill_na=0.0,
)
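# librosa.pyin returns (f0, voiced_flag, voiced_probs); fill_na=0.0 replaces
# the NaNs emitted for unvoiced frames, so f0 is zero wherever no pitch is
# detected.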
# f02 = librosa.yin(
# x,