mirror of https://github.com/coqui-ai/TTS.git
update Tacotron models for the trainer
parent
bdbfc95618
commit
535a458f40
|
@ -126,6 +126,7 @@ class TacotronConfig(BaseTTSConfig):
|
||||||
use_gst: bool = False
|
use_gst: bool = False
|
||||||
gst: GSTConfig = None
|
gst: GSTConfig = None
|
||||||
gst_style_input: str = None
|
gst_style_input: str = None
|
||||||
|
|
||||||
# model specific params
|
# model specific params
|
||||||
r: int = 2
|
r: int = 2
|
||||||
gradual_training: List[List[int]] = None
|
gradual_training: List[List[int]] = None
|
||||||
|
|
|
@ -113,7 +113,8 @@ class Tacotron(TacotronAbstract):
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
speaker_embedding_dim = 256
|
speaker_embedding_dim = 256
|
||||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
self.speaker_embedding = nn.Embedding(self.num_speakers,
|
||||||
|
speaker_embedding_dim)
|
||||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||||
|
|
||||||
# speaker and gst embeddings is concat in decoder input
|
# speaker and gst embeddings is concat in decoder input
|
||||||
|
@ -144,7 +145,8 @@ class Tacotron(TacotronAbstract):
|
||||||
separate_stopnet,
|
separate_stopnet,
|
||||||
)
|
)
|
||||||
self.postnet = PostCBHG(decoder_output_dim)
|
self.postnet = PostCBHG(decoder_output_dim)
|
||||||
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, postnet_output_dim)
|
self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2,
|
||||||
|
postnet_output_dim)
|
||||||
|
|
||||||
# setup prenet dropout
|
# setup prenet dropout
|
||||||
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
|
self.decoder.prenet.dropout_at_inference = prenet_dropout_at_inference
|
||||||
|
@ -181,93 +183,203 @@ class Tacotron(TacotronAbstract):
|
||||||
separate_stopnet,
|
separate_stopnet,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, characters, text_lengths, mel_specs, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
def forward(self,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
mel_specs=None,
|
||||||
|
mel_lengths=None,
|
||||||
|
cond_input=None):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
characters: [B, T_in]
|
text: [B, T_in]
|
||||||
text_lengths: [B]
|
text_lengths: [B]
|
||||||
mel_specs: [B, T_out, C]
|
mel_specs: [B, T_out, C]
|
||||||
mel_lengths: [B]
|
mel_lengths: [B]
|
||||||
speaker_ids: [B, 1]
|
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
||||||
speaker_embeddings: [B, C]
|
|
||||||
"""
|
"""
|
||||||
|
outputs = {
|
||||||
|
'alignments_backward': None,
|
||||||
|
'decoder_outputs_backward': None
|
||||||
|
}
|
||||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||||
# B x T_in x embed_dim
|
# B x T_in x embed_dim
|
||||||
inputs = self.embedding(characters)
|
inputs = self.embedding(text)
|
||||||
# B x T_in x encoder_in_features
|
# B x T_in x encoder_in_features
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
# sequence masking
|
# sequence masking
|
||||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(
|
||||||
|
encoder_outputs)
|
||||||
# global style token
|
# global style token
|
||||||
if self.gst and self.use_gst:
|
if self.gst and self.use_gst:
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs,
|
||||||
|
cond_input['x_vectors'])
|
||||||
# speaker embedding
|
# speaker embedding
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:,
|
||||||
|
None]
|
||||||
else:
|
else:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
encoder_outputs = self._concat_speaker_embedding(
|
||||||
|
encoder_outputs, speaker_embeddings)
|
||||||
# decoder_outputs: B x decoder_in_features x T_out
|
# decoder_outputs: B x decoder_in_features x T_out
|
||||||
# alignments: B x T_in x encoder_in_features
|
# alignments: B x T_in x encoder_in_features
|
||||||
# stop_tokens: B x T_in
|
# stop_tokens: B x T_in
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
|
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||||
|
encoder_outputs, mel_specs, input_mask)
|
||||||
# sequence masking
|
# sequence masking
|
||||||
if output_mask is not None:
|
if output_mask is not None:
|
||||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
decoder_outputs = decoder_outputs * output_mask.unsqueeze(
|
||||||
|
1).expand_as(decoder_outputs)
|
||||||
# B x T_out x decoder_in_features
|
# B x T_out x decoder_in_features
|
||||||
postnet_outputs = self.postnet(decoder_outputs)
|
postnet_outputs = self.postnet(decoder_outputs)
|
||||||
# sequence masking
|
# sequence masking
|
||||||
if output_mask is not None:
|
if output_mask is not None:
|
||||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(2).expand_as(postnet_outputs)
|
postnet_outputs = postnet_outputs * output_mask.unsqueeze(
|
||||||
|
2).expand_as(postnet_outputs)
|
||||||
# B x T_out x posnet_dim
|
# B x T_out x posnet_dim
|
||||||
postnet_outputs = self.last_linear(postnet_outputs)
|
postnet_outputs = self.last_linear(postnet_outputs)
|
||||||
# B x T_out x decoder_in_features
|
# B x T_out x decoder_in_features
|
||||||
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
decoder_outputs = decoder_outputs.transpose(1, 2).contiguous()
|
||||||
if self.bidirectional_decoder:
|
if self.bidirectional_decoder:
|
||||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
decoder_outputs_backward, alignments_backward = self._backward_pass(
|
||||||
return (
|
mel_specs, encoder_outputs, input_mask)
|
||||||
decoder_outputs,
|
outputs['alignments_backward'] = alignments_backward
|
||||||
postnet_outputs,
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
alignments,
|
|
||||||
stop_tokens,
|
|
||||||
decoder_outputs_backward,
|
|
||||||
alignments_backward,
|
|
||||||
)
|
|
||||||
if self.double_decoder_consistency:
|
if self.double_decoder_consistency:
|
||||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||||
mel_specs, encoder_outputs, alignments, input_mask
|
mel_specs, encoder_outputs, alignments, input_mask)
|
||||||
)
|
outputs['alignments_backward'] = alignments_backward
|
||||||
return (
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
decoder_outputs,
|
outputs.update({
|
||||||
postnet_outputs,
|
'postnet_outputs': postnet_outputs,
|
||||||
alignments,
|
'decoder_outputs': decoder_outputs,
|
||||||
stop_tokens,
|
'alignments': alignments,
|
||||||
decoder_outputs_backward,
|
'stop_tokens': stop_tokens
|
||||||
alignments_backward,
|
})
|
||||||
)
|
return outputs
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, characters, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
def inference(self,
|
||||||
inputs = self.embedding(characters)
|
text_input,
|
||||||
|
cond_input=None):
|
||||||
|
inputs = self.embedding(text_input)
|
||||||
encoder_outputs = self.encoder(inputs)
|
encoder_outputs = self.encoder(inputs)
|
||||||
if self.gst and self.use_gst:
|
if self.gst and self.use_gst:
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'],
|
||||||
|
cond_input['x_vectors'])
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
||||||
else:
|
else:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
encoder_outputs = self._concat_speaker_embedding(
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
encoder_outputs, speaker_embeddings)
|
||||||
|
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||||
|
encoder_outputs)
|
||||||
postnet_outputs = self.postnet(decoder_outputs)
|
postnet_outputs = self.postnet(decoder_outputs)
|
||||||
postnet_outputs = self.last_linear(postnet_outputs)
|
postnet_outputs = self.last_linear(postnet_outputs)
|
||||||
decoder_outputs = decoder_outputs.transpose(1, 2)
|
decoder_outputs = decoder_outputs.transpose(1, 2)
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
outputs = {
|
||||||
|
'postnet_outputs': postnet_outputs,
|
||||||
|
'decoder_outputs': decoder_outputs,
|
||||||
|
'alignments': alignments,
|
||||||
|
'stop_tokens': stop_tokens
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def train_step(self, batch, criterion):
|
||||||
|
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch ([type]): [description]
|
||||||
|
criterion ([type]): [description]
|
||||||
|
"""
|
||||||
|
text_input = batch['text_input']
|
||||||
|
text_lengths = batch['text_lengths']
|
||||||
|
mel_input = batch['mel_input']
|
||||||
|
mel_lengths = batch['mel_lengths']
|
||||||
|
linear_input = batch['linear_input']
|
||||||
|
stop_targets = batch['stop_targets']
|
||||||
|
speaker_ids = batch['speaker_ids']
|
||||||
|
x_vectors = batch['x_vectors']
|
||||||
|
|
||||||
|
# forward pass model
|
||||||
|
outputs = self.forward(text_input,
|
||||||
|
text_lengths,
|
||||||
|
mel_input,
|
||||||
|
mel_lengths,
|
||||||
|
cond_input={
|
||||||
|
'speaker_ids': speaker_ids,
|
||||||
|
'x_vectors': x_vectors
|
||||||
|
})
|
||||||
|
|
||||||
|
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||||
|
if mel_lengths.max() % self.decoder.r != 0:
|
||||||
|
alignment_lengths = (
|
||||||
|
mel_lengths +
|
||||||
|
(self.decoder.r -
|
||||||
|
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
|
||||||
|
else:
|
||||||
|
alignment_lengths = mel_lengths // self.decoder.r
|
||||||
|
|
||||||
|
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors}
|
||||||
|
outputs = self.forward(text_input, text_lengths, mel_input,
|
||||||
|
mel_lengths, cond_input)
|
||||||
|
|
||||||
|
# compute loss
|
||||||
|
loss_dict = criterion(
|
||||||
|
outputs['postnet_outputs'],
|
||||||
|
outputs['decoder_outputs'],
|
||||||
|
mel_input,
|
||||||
|
linear_input,
|
||||||
|
outputs['stop_tokens'],
|
||||||
|
stop_targets,
|
||||||
|
mel_lengths,
|
||||||
|
outputs['decoder_outputs_backward'],
|
||||||
|
outputs['alignments'],
|
||||||
|
alignment_lengths,
|
||||||
|
outputs['alignments_backward'],
|
||||||
|
text_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute alignment error (the lower the better )
|
||||||
|
align_error = 1 - alignment_diagonal_score(outputs['alignments'])
|
||||||
|
loss_dict["align_error"] = align_error
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(self, ap, batch, outputs):
|
||||||
|
postnet_outputs = outputs['postnet_outputs']
|
||||||
|
alignments = outputs['alignments']
|
||||||
|
alignments_backward = outputs['alignments_backward']
|
||||||
|
mel_input = batch['mel_input']
|
||||||
|
|
||||||
|
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||||
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
figures = {
|
||||||
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||||
|
figures["alignment_backward"] = plot_alignment(
|
||||||
|
alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
train_audio = ap.inv_spectrogram(pred_spec.T)
|
||||||
|
return figures, train_audio
|
||||||
|
|
||||||
|
def eval_step(self, batch, criterion):
|
||||||
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def eval_log(self, ap, batch, outputs):
|
||||||
|
return self.train_log(ap, batch, outputs)
|
|
@ -1,12 +1,15 @@
|
||||||
|
# coding: utf-8
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
from TTS.tts.utils.measures import alignment_diagonal_score
|
||||||
|
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
|
||||||
from TTS.tts.layers.tacotron.gst_layers import GST
|
from TTS.tts.layers.tacotron.gst_layers import GST
|
||||||
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
from TTS.tts.layers.tacotron.tacotron2 import Decoder, Encoder, Postnet
|
||||||
from TTS.tts.models.tacotron_abstract import TacotronAbstract
|
from TTS.tts.models.tacotron_abstract import TacotronAbstract
|
||||||
|
|
||||||
|
|
||||||
# TODO: match function arguments with tacotron
|
|
||||||
class Tacotron2(TacotronAbstract):
|
class Tacotron2(TacotronAbstract):
|
||||||
"""Tacotron2 as in https://arxiv.org/abs/1712.05884
|
"""Tacotron2 as in https://arxiv.org/abs/1712.05884
|
||||||
|
|
||||||
|
@ -43,69 +46,52 @@ class Tacotron2(TacotronAbstract):
|
||||||
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
speaker_embedding_dim (int, optional): external speaker conditioning vector channels. Defaults to None.
|
||||||
use_gst (bool, optional): enable/disable Global style token module.
|
use_gst (bool, optional): enable/disable Global style token module.
|
||||||
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
gst (Coqpit, optional): Coqpit to initialize the GST module. If `None`, GST is disabled. Defaults to None.
|
||||||
|
gradual_trainin (List): Gradual training schedule. If None or `[]`, no gradual training is used.
|
||||||
|
Defaults to `[]`.
|
||||||
"""
|
"""
|
||||||
|
def __init__(self,
|
||||||
def __init__(
|
num_chars,
|
||||||
self,
|
num_speakers,
|
||||||
num_chars,
|
r,
|
||||||
num_speakers,
|
postnet_output_dim=80,
|
||||||
r,
|
decoder_output_dim=80,
|
||||||
postnet_output_dim=80,
|
attn_type="original",
|
||||||
decoder_output_dim=80,
|
attn_win=False,
|
||||||
attn_type="original",
|
attn_norm="softmax",
|
||||||
attn_win=False,
|
prenet_type="original",
|
||||||
attn_norm="softmax",
|
prenet_dropout=True,
|
||||||
prenet_type="original",
|
prenet_dropout_at_inference=False,
|
||||||
prenet_dropout=True,
|
forward_attn=False,
|
||||||
prenet_dropout_at_inference=False,
|
trans_agent=False,
|
||||||
forward_attn=False,
|
forward_attn_mask=False,
|
||||||
trans_agent=False,
|
location_attn=True,
|
||||||
forward_attn_mask=False,
|
attn_K=5,
|
||||||
location_attn=True,
|
separate_stopnet=True,
|
||||||
attn_K=5,
|
bidirectional_decoder=False,
|
||||||
separate_stopnet=True,
|
double_decoder_consistency=False,
|
||||||
bidirectional_decoder=False,
|
ddc_r=None,
|
||||||
double_decoder_consistency=False,
|
encoder_in_features=512,
|
||||||
ddc_r=None,
|
decoder_in_features=512,
|
||||||
encoder_in_features=512,
|
speaker_embedding_dim=None,
|
||||||
decoder_in_features=512,
|
use_gst=False,
|
||||||
speaker_embedding_dim=None,
|
gst=None,
|
||||||
use_gst=False,
|
gradual_training=[]):
|
||||||
gst=None,
|
super().__init__(num_chars, num_speakers, r, postnet_output_dim,
|
||||||
):
|
decoder_output_dim, attn_type, attn_win, attn_norm,
|
||||||
super().__init__(
|
prenet_type, prenet_dropout,
|
||||||
num_chars,
|
prenet_dropout_at_inference, forward_attn,
|
||||||
num_speakers,
|
trans_agent, forward_attn_mask, location_attn, attn_K,
|
||||||
r,
|
separate_stopnet, bidirectional_decoder,
|
||||||
postnet_output_dim,
|
double_decoder_consistency, ddc_r,
|
||||||
decoder_output_dim,
|
encoder_in_features, decoder_in_features,
|
||||||
attn_type,
|
speaker_embedding_dim, use_gst, gst, gradual_training)
|
||||||
attn_win,
|
|
||||||
attn_norm,
|
|
||||||
prenet_type,
|
|
||||||
prenet_dropout,
|
|
||||||
prenet_dropout_at_inference,
|
|
||||||
forward_attn,
|
|
||||||
trans_agent,
|
|
||||||
forward_attn_mask,
|
|
||||||
location_attn,
|
|
||||||
attn_K,
|
|
||||||
separate_stopnet,
|
|
||||||
bidirectional_decoder,
|
|
||||||
double_decoder_consistency,
|
|
||||||
ddc_r,
|
|
||||||
encoder_in_features,
|
|
||||||
decoder_in_features,
|
|
||||||
speaker_embedding_dim,
|
|
||||||
use_gst,
|
|
||||||
gst,
|
|
||||||
)
|
|
||||||
|
|
||||||
# speaker embedding layer
|
# speaker embedding layer
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
speaker_embedding_dim = 512
|
speaker_embedding_dim = 512
|
||||||
self.speaker_embedding = nn.Embedding(self.num_speakers, speaker_embedding_dim)
|
self.speaker_embedding = nn.Embedding(self.num_speakers,
|
||||||
|
speaker_embedding_dim)
|
||||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||||
|
|
||||||
# speaker and gst embeddings is concat in decoder input
|
# speaker and gst embeddings is concat in decoder input
|
||||||
|
@ -176,16 +162,24 @@ class Tacotron2(TacotronAbstract):
|
||||||
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
mel_outputs_postnet = mel_outputs_postnet.transpose(1, 2)
|
||||||
return mel_outputs, mel_outputs_postnet, alignments
|
return mel_outputs, mel_outputs_postnet, alignments
|
||||||
|
|
||||||
def forward(self, text, text_lengths, mel_specs=None, mel_lengths=None, speaker_ids=None, speaker_embeddings=None):
|
def forward(self,
|
||||||
|
text,
|
||||||
|
text_lengths,
|
||||||
|
mel_specs=None,
|
||||||
|
mel_lengths=None,
|
||||||
|
cond_input=None):
|
||||||
"""
|
"""
|
||||||
Shapes:
|
Shapes:
|
||||||
text: [B, T_in]
|
text: [B, T_in]
|
||||||
text_lengths: [B]
|
text_lengths: [B]
|
||||||
mel_specs: [B, T_out, C]
|
mel_specs: [B, T_out, C]
|
||||||
mel_lengths: [B]
|
mel_lengths: [B]
|
||||||
speaker_ids: [B, 1]
|
cond_input: 'speaker_ids': [B, 1] and 'x_vectors':[B, C]
|
||||||
speaker_embeddings: [B, C]
|
|
||||||
"""
|
"""
|
||||||
|
outputs = {
|
||||||
|
'alignments_backward': None,
|
||||||
|
'decoder_outputs_backward': None
|
||||||
|
}
|
||||||
# compute mask for padding
|
# compute mask for padding
|
||||||
# B x T_in_max (boolean)
|
# B x T_in_max (boolean)
|
||||||
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
input_mask, output_mask = self.compute_masks(text_lengths, mel_lengths)
|
||||||
|
@ -195,94 +189,176 @@ class Tacotron2(TacotronAbstract):
|
||||||
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
|
||||||
if self.gst and self.use_gst:
|
if self.gst and self.use_gst:
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs, speaker_embeddings)
|
encoder_outputs = self.compute_gst(encoder_outputs, mel_specs,
|
||||||
|
cond_input['x_vectors'])
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
speaker_embeddings = self.speaker_embedding(cond_input['speaker_ids'])[:,
|
||||||
|
None]
|
||||||
else:
|
else:
|
||||||
# B x 1 x speaker_embed_dim
|
# B x 1 x speaker_embed_dim
|
||||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 1)
|
speaker_embeddings = torch.unsqueeze(cond_input['x_vectors'], 1)
|
||||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
encoder_outputs = self._concat_speaker_embedding(
|
||||||
|
encoder_outputs, speaker_embeddings)
|
||||||
|
|
||||||
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(encoder_outputs)
|
encoder_outputs = encoder_outputs * input_mask.unsqueeze(2).expand_as(
|
||||||
|
encoder_outputs)
|
||||||
|
|
||||||
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
# B x mel_dim x T_out -- B x T_out//r x T_in -- B x T_out//r
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder(encoder_outputs, mel_specs, input_mask)
|
decoder_outputs, alignments, stop_tokens = self.decoder(
|
||||||
|
encoder_outputs, mel_specs, input_mask)
|
||||||
# sequence masking
|
# sequence masking
|
||||||
if mel_lengths is not None:
|
if mel_lengths is not None:
|
||||||
decoder_outputs = decoder_outputs * output_mask.unsqueeze(1).expand_as(decoder_outputs)
|
decoder_outputs = decoder_outputs * output_mask.unsqueeze(
|
||||||
|
1).expand_as(decoder_outputs)
|
||||||
# B x mel_dim x T_out
|
# B x mel_dim x T_out
|
||||||
postnet_outputs = self.postnet(decoder_outputs)
|
postnet_outputs = self.postnet(decoder_outputs)
|
||||||
postnet_outputs = decoder_outputs + postnet_outputs
|
postnet_outputs = decoder_outputs + postnet_outputs
|
||||||
# sequence masking
|
# sequence masking
|
||||||
if output_mask is not None:
|
if output_mask is not None:
|
||||||
postnet_outputs = postnet_outputs * output_mask.unsqueeze(1).expand_as(postnet_outputs)
|
postnet_outputs = postnet_outputs * output_mask.unsqueeze(
|
||||||
|
1).expand_as(postnet_outputs)
|
||||||
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
|
# B x T_out x mel_dim -- B x T_out x mel_dim -- B x T_out//r x T_in
|
||||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||||
|
decoder_outputs, postnet_outputs, alignments)
|
||||||
if self.bidirectional_decoder:
|
if self.bidirectional_decoder:
|
||||||
decoder_outputs_backward, alignments_backward = self._backward_pass(mel_specs, encoder_outputs, input_mask)
|
decoder_outputs_backward, alignments_backward = self._backward_pass(
|
||||||
return (
|
mel_specs, encoder_outputs, input_mask)
|
||||||
decoder_outputs,
|
outputs['alignments_backward'] = alignments_backward
|
||||||
postnet_outputs,
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
alignments,
|
|
||||||
stop_tokens,
|
|
||||||
decoder_outputs_backward,
|
|
||||||
alignments_backward,
|
|
||||||
)
|
|
||||||
if self.double_decoder_consistency:
|
if self.double_decoder_consistency:
|
||||||
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
decoder_outputs_backward, alignments_backward = self._coarse_decoder_pass(
|
||||||
mel_specs, encoder_outputs, alignments, input_mask
|
mel_specs, encoder_outputs, alignments, input_mask)
|
||||||
)
|
outputs['alignments_backward'] = alignments_backward
|
||||||
return (
|
outputs['decoder_outputs_backward'] = decoder_outputs_backward
|
||||||
decoder_outputs,
|
outputs.update({
|
||||||
postnet_outputs,
|
'postnet_outputs': postnet_outputs,
|
||||||
alignments,
|
'decoder_outputs': decoder_outputs,
|
||||||
stop_tokens,
|
'alignments': alignments,
|
||||||
decoder_outputs_backward,
|
'stop_tokens': stop_tokens
|
||||||
alignments_backward,
|
})
|
||||||
)
|
return outputs
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def inference(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
def inference(self, text, cond_input=None):
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
embedded_inputs = self.embedding(text).transpose(1, 2)
|
||||||
encoder_outputs = self.encoder.inference(embedded_inputs)
|
encoder_outputs = self.encoder.inference(embedded_inputs)
|
||||||
|
|
||||||
if self.gst and self.use_gst:
|
if self.gst and self.use_gst:
|
||||||
# B x gst_dim
|
# B x gst_dim
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
encoder_outputs = self.compute_gst(encoder_outputs, cond_input['style_mel'],
|
||||||
|
cond_input['x_vectors'])
|
||||||
if self.num_speakers > 1:
|
if self.num_speakers > 1:
|
||||||
if not self.embeddings_per_sample:
|
if not self.embeddings_per_sample:
|
||||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
x_vector = self.speaker_embedding(cond_input['speaker_ids'])[:, None]
|
||||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2)
|
x_vector = torch.unsqueeze(x_vector, 0).transpose(1, 2)
|
||||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
else:
|
||||||
|
x_vector = cond_input
|
||||||
|
|
||||||
decoder_outputs, alignments, stop_tokens = self.decoder.inference(encoder_outputs)
|
encoder_outputs = self._concat_speaker_embedding(
|
||||||
|
encoder_outputs, x_vector)
|
||||||
|
|
||||||
|
decoder_outputs, alignments, stop_tokens = self.decoder.inference(
|
||||||
|
encoder_outputs)
|
||||||
postnet_outputs = self.postnet(decoder_outputs)
|
postnet_outputs = self.postnet(decoder_outputs)
|
||||||
postnet_outputs = decoder_outputs + postnet_outputs
|
postnet_outputs = decoder_outputs + postnet_outputs
|
||||||
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(decoder_outputs, postnet_outputs, alignments)
|
decoder_outputs, postnet_outputs, alignments = self.shape_outputs(
|
||||||
return decoder_outputs, postnet_outputs, alignments, stop_tokens
|
decoder_outputs, postnet_outputs, alignments)
|
||||||
|
outputs = {
|
||||||
|
'postnet_outputs': postnet_outputs,
|
||||||
|
'decoder_outputs': decoder_outputs,
|
||||||
|
'alignments': alignments,
|
||||||
|
'stop_tokens': stop_tokens
|
||||||
|
}
|
||||||
|
return outputs
|
||||||
|
|
||||||
def inference_truncated(self, text, speaker_ids=None, style_mel=None, speaker_embeddings=None):
|
def train_step(self, batch, criterion):
|
||||||
|
"""Perform a single training step by fetching the right set if samples from the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch ([type]): [description]
|
||||||
|
criterion ([type]): [description]
|
||||||
"""
|
"""
|
||||||
Preserve model states for continuous inference
|
text_input = batch['text_input']
|
||||||
"""
|
text_lengths = batch['text_lengths']
|
||||||
embedded_inputs = self.embedding(text).transpose(1, 2)
|
mel_input = batch['mel_input']
|
||||||
encoder_outputs = self.encoder.inference_truncated(embedded_inputs)
|
mel_lengths = batch['mel_lengths']
|
||||||
|
linear_input = batch['linear_input']
|
||||||
|
stop_targets = batch['stop_targets']
|
||||||
|
speaker_ids = batch['speaker_ids']
|
||||||
|
x_vectors = batch['x_vectors']
|
||||||
|
|
||||||
if self.gst:
|
# forward pass model
|
||||||
# B x gst_dim
|
outputs = self.forward(text_input,
|
||||||
encoder_outputs = self.compute_gst(encoder_outputs, style_mel, speaker_embeddings)
|
text_lengths,
|
||||||
|
mel_input,
|
||||||
|
mel_lengths,
|
||||||
|
cond_input={
|
||||||
|
'speaker_ids': speaker_ids,
|
||||||
|
'x_vectors': x_vectors
|
||||||
|
})
|
||||||
|
|
||||||
if self.num_speakers > 1:
|
# set the [alignment] lengths wrt reduction factor for guided attention
|
||||||
if not self.embeddings_per_sample:
|
if mel_lengths.max() % self.decoder.r != 0:
|
||||||
speaker_embeddings = self.speaker_embedding(speaker_ids)[:, None]
|
alignment_lengths = (
|
||||||
speaker_embeddings = torch.unsqueeze(speaker_embeddings, 0).transpose(1, 2)
|
mel_lengths +
|
||||||
encoder_outputs = self._concat_speaker_embedding(encoder_outputs, speaker_embeddings)
|
(self.decoder.r -
|
||||||
|
(mel_lengths.max() % self.decoder.r))) // self.decoder.r
|
||||||
|
else:
|
||||||
|
alignment_lengths = mel_lengths // self.decoder.r
|
||||||
|
|
||||||
mel_outputs, alignments, stop_tokens = self.decoder.inference_truncated(encoder_outputs)
|
cond_input = {'speaker_ids': speaker_ids, 'x_vectors': x_vectors}
|
||||||
mel_outputs_postnet = self.postnet(mel_outputs)
|
outputs = self.forward(text_input, text_lengths, mel_input,
|
||||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
mel_lengths, cond_input)
|
||||||
mel_outputs, mel_outputs_postnet, alignments = self.shape_outputs(mel_outputs, mel_outputs_postnet, alignments)
|
|
||||||
return mel_outputs, mel_outputs_postnet, alignments, stop_tokens
|
# compute loss
|
||||||
|
loss_dict = criterion(
|
||||||
|
outputs['model_outputs'],
|
||||||
|
outputs['decoder_outputs'],
|
||||||
|
mel_input,
|
||||||
|
linear_input,
|
||||||
|
outputs['stop_tokens'],
|
||||||
|
stop_targets,
|
||||||
|
mel_lengths,
|
||||||
|
outputs['decoder_outputs_backward'],
|
||||||
|
outputs['alignments'],
|
||||||
|
alignment_lengths,
|
||||||
|
outputs['alignments_backward'],
|
||||||
|
text_lengths,
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute alignment error (the lower the better )
|
||||||
|
align_error = 1 - alignment_diagonal_score(outputs['alignments'])
|
||||||
|
loss_dict["align_error"] = align_error
|
||||||
|
return outputs, loss_dict
|
||||||
|
|
||||||
|
def train_log(self, ap, batch, outputs):
|
||||||
|
postnet_outputs = outputs['model_outputs']
|
||||||
|
alignments = outputs['alignments']
|
||||||
|
alignments_backward = outputs['alignments_backward']
|
||||||
|
mel_input = batch['mel_input']
|
||||||
|
|
||||||
|
pred_spec = postnet_outputs[0].data.cpu().numpy()
|
||||||
|
gt_spec = mel_input[0].data.cpu().numpy()
|
||||||
|
align_img = alignments[0].data.cpu().numpy()
|
||||||
|
|
||||||
|
figures = {
|
||||||
|
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
|
||||||
|
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
|
||||||
|
"alignment": plot_alignment(align_img, output_fig=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.bidirectional_decoder or self.double_decoder_consistency:
|
||||||
|
figures["alignment_backward"] = plot_alignment(
|
||||||
|
alignments_backward[0].data.cpu().numpy(), output_fig=False)
|
||||||
|
|
||||||
|
# Sample audio
|
||||||
|
train_audio = ap.inv_melspectrogram(pred_spec.T)
|
||||||
|
return figures, train_audio
|
||||||
|
|
||||||
|
def eval_step(self, batch, criterion):
|
||||||
|
return self.train_step(batch, criterion)
|
||||||
|
|
||||||
|
def eval_log(self, ap, batch, outputs):
|
||||||
|
return self.train_log(ap, batch, outputs)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import copy
|
import copy
|
||||||
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from TTS.tts.utils.generic_utils import sequence_mask
|
from TTS.tts.utils.data import sequence_mask
|
||||||
|
from TTS.utils.training import gradual_training_scheduler
|
||||||
|
|
||||||
|
|
||||||
class TacotronAbstract(ABC, nn.Module):
|
class TacotronAbstract(ABC, nn.Module):
|
||||||
|
@ -35,6 +37,7 @@ class TacotronAbstract(ABC, nn.Module):
|
||||||
speaker_embedding_dim=None,
|
speaker_embedding_dim=None,
|
||||||
use_gst=False,
|
use_gst=False,
|
||||||
gst=None,
|
gst=None,
|
||||||
|
gradual_training=[]
|
||||||
):
|
):
|
||||||
"""Abstract Tacotron class"""
|
"""Abstract Tacotron class"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -63,6 +66,7 @@ class TacotronAbstract(ABC, nn.Module):
|
||||||
self.encoder_in_features = encoder_in_features
|
self.encoder_in_features = encoder_in_features
|
||||||
self.decoder_in_features = decoder_in_features
|
self.decoder_in_features = decoder_in_features
|
||||||
self.speaker_embedding_dim = speaker_embedding_dim
|
self.speaker_embedding_dim = speaker_embedding_dim
|
||||||
|
self.gradual_training = gradual_training
|
||||||
|
|
||||||
# layers
|
# layers
|
||||||
self.embedding = None
|
self.embedding = None
|
||||||
|
@ -216,3 +220,23 @@ class TacotronAbstract(ABC, nn.Module):
|
||||||
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
|
speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), outputs.size(1), -1)
|
||||||
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
#############################
|
||||||
|
# CALLBACKS
|
||||||
|
#############################
|
||||||
|
|
||||||
|
def on_epoch_start(self, trainer):
|
||||||
|
"""Callback for setting values wrt gradual training schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
trainer (TrainerTTS): TTS trainer object that is used to train this model.
|
||||||
|
"""
|
||||||
|
if self.gradual_training:
|
||||||
|
r, trainer.config.batch_size = gradual_training_scheduler(trainer.total_steps_done, trainer.config)
|
||||||
|
trainer.config.r = r
|
||||||
|
self.decoder.set_r(r)
|
||||||
|
if trainer.config.bidirectional_decoder:
|
||||||
|
trainer.model.decoder_backward.set_r(r)
|
||||||
|
trainer.train_loader = trainer.setup_train_dataloader(self.ap, self.model.decoder.r, verbose=True)
|
||||||
|
trainer.eval_loader = trainer.setup_eval_dataloder(self.ap, self.model.decoder.r)
|
||||||
|
logging.info(f"\n > Number of output frames: {self.decoder.r}")
|
||||||
|
|
Loading…
Reference in New Issue