update `speedy_speech.py` model for trainer

pull/602/head
Eren Gölge 2021-05-26 16:03:24 +02:00
parent 843b3ba960
commit f121b0ff5d
1 changed files with 121 additions and 18 deletions

View File

@ -3,6 +3,9 @@ from torch import nn
from TTS.tts.layers.feed_forward.decoder import Decoder
from TTS.tts.layers.feed_forward.duration_predictor import DurationPredictor
from TTS.tts.utils.measures import alignment_diagonal_score
from TTS.tts.utils.visual import plot_alignment, plot_spectrogram
from TTS.utils.audio import AudioProcessor
from TTS.tts.layers.feed_forward.encoder import Encoder
from TTS.tts.layers.generic.pos_encoding import PositionalEncoding
from TTS.tts.layers.glow_tts.monotonic_align import generate_path
@ -46,7 +49,12 @@ class SpeedySpeech(nn.Module):
positional_encoding=True,
length_scale=1,
encoder_type="residual_conv_bn",
encoder_params={"kernel_size": 4, "dilations": 4 * [1, 2, 4] + [1], "num_conv_blocks": 2, "num_res_blocks": 13},
encoder_params={
"kernel_size": 4,
"dilations": 4 * [1, 2, 4] + [1],
"num_conv_blocks": 2,
"num_res_blocks": 13
},
decoder_type="residual_conv_bn",
decoder_params={
"kernel_size": 4,
@ -60,13 +68,17 @@ class SpeedySpeech(nn.Module):
):
super().__init__()
self.length_scale = float(length_scale) if isinstance(length_scale, int) else length_scale
self.length_scale = float(length_scale) if isinstance(
length_scale, int) else length_scale
self.emb = nn.Embedding(num_chars, hidden_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type, encoder_params, c_in_channels)
self.encoder = Encoder(hidden_channels, hidden_channels, encoder_type,
encoder_params, c_in_channels)
if positional_encoding:
self.pos_encoder = PositionalEncoding(hidden_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type, decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels + c_in_channels)
self.decoder = Decoder(out_channels, hidden_channels, decoder_type,
decoder_params)
self.duration_predictor = DurationPredictor(hidden_channels +
c_in_channels)
if num_speakers > 1 and not external_c:
# speaker embedding layer
@ -93,7 +105,9 @@ class SpeedySpeech(nn.Module):
"""
attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(y_mask, 2)
attn = generate_path(dr, attn_mask.squeeze(1)).to(en.dtype)
o_en_ex = torch.matmul(attn.squeeze(1).transpose(1, 2), en.transpose(1, 2)).transpose(1, 2)
o_en_ex = torch.matmul(
attn.squeeze(1).transpose(1, 2), en.transpose(1,
2)).transpose(1, 2)
return o_en_ex, attn
def format_durations(self, o_dr_log, x_mask):
@ -127,7 +141,8 @@ class SpeedySpeech(nn.Module):
x_emb = torch.transpose(x_emb, 1, -1)
# compute sequence masks
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]), 1).to(x.dtype)
x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.shape[1]),
1).to(x.dtype)
# encoder pass
o_en = self.encoder(x_emb, x_mask)
@ -140,7 +155,8 @@ class SpeedySpeech(nn.Module):
return o_en, o_en_dp, x_mask, g
def _forward_decoder(self, o_en, o_en_dp, dr, x_mask, y_lengths, g):
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(o_en_dp.dtype)
y_mask = torch.unsqueeze(sequence_mask(y_lengths, None),
1).to(o_en_dp.dtype)
# expand o_en with durations
o_en_ex, attn = self.expand_encoder_outputs(o_en, dr, x_mask, y_mask)
# positional encoding
@ -153,8 +169,17 @@ class SpeedySpeech(nn.Module):
o_de = self.decoder(o_en_ex, y_mask, g=g)
return o_de, attn.transpose(1, 2)
def forward(self, x, x_lengths, y_lengths, dr, g=None): # pylint: disable=unused-argument
def forward(self,
x,
x_lengths,
y_lengths,
dr,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
"""
TODO: speaker embedding for speaker_ids
Shapes:
x: [B, T_max]
x_lengths: [B]
@ -162,35 +187,113 @@ class SpeedySpeech(nn.Module):
dr: [B, T_max]
g: [B, C]
"""
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_de, attn = self._forward_decoder(o_en, o_en_dp, dr, x_mask, y_lengths, g=g)
return o_de, o_dr_log.squeeze(1), attn
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'durations_log': o_dr_log.squeeze(1),
'alignments': attn
}
return outputs
def inference(self, x, x_lengths, g=None): # pylint: disable=unused-argument
def inference(self,
x,
cond_input={
'x_vectors': None,
'speaker_ids': None
}): # pylint: disable=unused-argument
"""
Shapes:
x: [B, T_max]
x_lengths: [B]
g: [B, C]
"""
g = cond_input['x_vectors'] if 'x_vectors' in cond_input else None
x_lengths = torch.tensor(x.shape[1:2]).to(x.device)
# input sequence should be greated than the max convolution size
inference_padding = 5
if x.shape[1] < 13:
inference_padding += 13 - x.shape[1]
# pad input to prevent dropping the last word
x = torch.nn.functional.pad(x, pad=(0, inference_padding), mode="constant", value=0)
x = torch.nn.functional.pad(x,
pad=(0, inference_padding),
mode="constant",
value=0)
o_en, o_en_dp, x_mask, g = self._forward_encoder(x, x_lengths, g)
# duration predictor pass
o_dr_log = self.duration_predictor(o_en_dp.detach(), x_mask)
o_dr = self.format_durations(o_dr_log, x_mask).squeeze(1)
y_lengths = o_dr.sum(1)
o_de, attn = self._forward_decoder(o_en, o_en_dp, o_dr, x_mask, y_lengths, g=g)
return o_de, attn
o_de, attn = self._forward_decoder(o_en,
o_en_dp,
o_dr,
x_mask,
y_lengths,
g=g)
outputs = {
'model_outputs': o_de.transpose(1, 2),
'alignments': attn,
'durations_log': None
}
return outputs
def load_checkpoint(
self, config, checkpoint_path, eval=False
): # pylint: disable=unused-argument, redefined-builtin
def train_step(self, batch: dict, criterion: nn.Module):
text_input = batch['text_input']
text_lengths = batch['text_lengths']
mel_input = batch['mel_input']
mel_lengths = batch['mel_lengths']
x_vectors = batch['x_vectors']
speaker_ids = batch['speaker_ids']
durations = batch['durations']
cond_input = {'x_vectors': x_vectors, 'speaker_ids': speaker_ids}
outputs = self.forward(text_input, text_lengths, mel_lengths,
durations, cond_input)
# compute loss
loss_dict = criterion(outputs['model_outputs'], mel_input,
mel_lengths, outputs['durations_log'],
torch.log(1 + durations), text_lengths)
# compute alignment error (the lower the better )
align_error = 1 - alignment_diagonal_score(outputs['alignments'],
binary=True)
loss_dict["align_error"] = align_error
return outputs, loss_dict
def train_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
model_outputs = outputs['model_outputs']
alignments = outputs['alignments']
mel_input = batch['mel_input']
pred_spec = model_outputs[0].data.cpu().numpy()
gt_spec = mel_input[0].data.cpu().numpy()
align_img = alignments[0].data.cpu().numpy()
figures = {
"prediction": plot_spectrogram(pred_spec, ap, output_fig=False),
"ground_truth": plot_spectrogram(gt_spec, ap, output_fig=False),
"alignment": plot_alignment(align_img, output_fig=False),
}
# Sample audio
train_audio = ap.inv_melspectrogram(pred_spec.T)
return figures, train_audio
def eval_step(self, batch: dict, criterion: nn.Module):
return self.train_step(batch, criterion)
def eval_log(self, ap: AudioProcessor, batch: dict, outputs: dict):
return self.train_log(ap, batch, outputs)
def load_checkpoint(self, config, checkpoint_path, eval=False): # pylint: disable=unused-argument, redefined-builtin
state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
self.load_state_dict(state["model"])
if eval: