From a1322530dfd63ec2b9433699b36a647774f9aaa5 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Thu, 12 Sep 2019 10:39:15 +0200 Subject: [PATCH] integrade concatinative speker embedding to tacotron --- layers/tacotron.py | 17 +++++++---- models/tacotron.py | 70 +++++++++++++++++++++++++++++++------------- tests/test_layers.py | 31 +++++++++++++++++++- 3 files changed, 91 insertions(+), 27 deletions(-) diff --git a/layers/tacotron.py b/layers/tacotron.py index 788e5230..411e7e72 100644 --- a/layers/tacotron.py +++ b/layers/tacotron.py @@ -273,7 +273,7 @@ class Decoder(nn.Module): def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, location_attn, - separate_stopnet): + separate_stopnet, speaker_embedding_dim): super(Decoder, self).__init__() self.r_init = r self.r = r @@ -285,8 +285,9 @@ class Decoder(nn.Module): self.separate_stopnet = separate_stopnet self.query_dim = 256 # memory -> |Prenet| -> processed_memory + prenet_dim = memory_dim * self.memory_size + speaker_embedding_dim if self.use_memory_queue else memory_dim + speaker_embedding_dim self.prenet = Prenet( - memory_dim * self.memory_size if self.use_memory_queue else memory_dim, + prenet_dim, prenet_type, prenet_dropout, out_features=[256, 128]) @@ -407,7 +408,7 @@ class Decoder(nn.Module): # use only the last frame prediction self.memory_input = new_memory[:, :self.memory_dim] - def forward(self, inputs, memory, mask): + def forward(self, inputs, memory, mask, speaker_embeddings=None): """ Args: inputs: Encoder outputs. @@ -432,6 +433,8 @@ class Decoder(nn.Module): if t > 0: new_memory = memory[t - 1] self._update_memory_input(new_memory) + if speaker_embeddings is not None: + self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) output, stop_token, attention = self.decode(inputs, mask) outputs += [output] attentions += [attention] @@ -440,13 +443,15 @@ class Decoder(nn.Module): return self._parse_outputs(outputs, attentions, stop_tokens) - def inference(self, inputs): + def inference(self, inputs, speaker_embeddings=None): """ Args: - inputs: Encoder outputs. + inputs: encoder outputs. + speaker_embeddings: speaker vectors. Shapes: - inputs: batch x time x encoder_out_dim + - speaker_embeddings: batch x embed_dim """ outputs = [] attentions = [] @@ -459,6 +464,8 @@ class Decoder(nn.Module): if t > 0: new_memory = outputs[-1] self._update_memory_input(new_memory) + if speaker_embeddings is not None: + self.memory_input = torch.cat([self.memory_input, speaker_embeddings], dim=-1) output, stop_token, attention = self.decode(inputs, None) stop_token = torch.sigmoid(stop_token.data) outputs += [output] diff --git a/models/tacotron.py b/models/tacotron.py index 69a6fa03..bd2a3ac7 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -1,4 +1,5 @@ # coding: utf-8 +import torch from torch import nn from TTS.layers.tacotron import Encoder, Decoder, PostCBHG from TTS.utils.generic_utils import sequence_mask @@ -25,28 +26,50 @@ class Tacotron(nn.Module): self.r = r self.mel_dim = mel_dim self.linear_dim = linear_dim + self.num_speakers = num_speakers self.embedding = nn.Embedding(num_chars, 256) self.embedding.weight.data.normal_(0, 0.3) + decoder_dim = 512 if num_speakers > 1 else 256 + encoder_dim = 512 if num_speakers > 1 else 256 + proj_speaker_dim = 80 if num_speakers > 1 else 0 if num_speakers > 1: self.speaker_embedding = nn.Embedding(num_speakers, 256) self.speaker_embedding.weight.data.normal_(0, 0.3) - self.encoder = Encoder(256) - self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, + self.speaker_project_mel = nn.Sequential(nn.Linear(256, proj_speaker_dim), nn.Tanh()) + self.encoder = Encoder(encoder_dim) + self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, - location_attn, separate_stopnet) + location_attn, separate_stopnet, proj_speaker_dim) self.postnet = PostCBHG(mel_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim) - + + def __init_states(self): + self.speaker_embeddings = None + self.speaker_embeddings_projected = None + + def compute_speaker_embedding(self, speaker_ids): + if hasattr(self, "speaker_embedding") and speaker_ids is None: + raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") + if hasattr(self, "speaker_embedding") and speaker_ids is not None: + self.speaker_embeddings = self._compute_speaker_embedding(speaker_ids) + self.speaker_embeddings_projected = self.speaker_project_mel(self.speaker_embeddings).squeeze(1) + def forward(self, characters, text_lengths, mel_specs, speaker_ids=None): B = characters.size(0) mask = sequence_mask(text_lengths).to(characters.device) inputs = self.embedding(characters) + self.__init_states() + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + inputs = self._concat_speaker_embedding(inputs, + self.speaker_embeddings) encoder_outputs = self.encoder(inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + if self.num_speakers > 1: + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, + self.speaker_embeddings) mel_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, mask) + encoder_outputs, mel_specs, mask, self.speaker_embeddings_projected) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) @@ -55,25 +78,30 @@ class Tacotron(nn.Module): def inference(self, characters, speaker_ids=None): B = characters.size(0) inputs = self.embedding(characters) + self.__init_states() + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + inputs = self._concat_speaker_embedding(inputs, + self.speaker_embeddings) encoder_outputs = self.encoder(inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + if self.num_speakers > 1: + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, + self.speaker_embeddings) mel_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + encoder_outputs, self.speaker_embeddings_projected) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens - def _add_speaker_embedding(self, encoder_outputs, speaker_ids): - if hasattr(self, "speaker_embedding") and speaker_ids is None: - raise RuntimeError(" [!] Model has speaker embedding layer but speaker_id is not provided") - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - speaker_embeddings = self.speaker_embedding(speaker_ids) + def _compute_speaker_embedding(self, speaker_ids): + speaker_embeddings = self.speaker_embedding(speaker_ids) + return speaker_embeddings.unsqueeze_(1) + + def _concat_speaker_embedding(self, outputs, speaker_embeddings): + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), + outputs.size(1), + -1) + outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) + return outputs - speaker_embeddings.unsqueeze_(1) - speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), - encoder_outputs.size(1), - -1) - encoder_outputs = encoder_outputs + speaker_embeddings - return encoder_outputs diff --git a/tests/test_layers.py b/tests/test_layers.py index cf27e30c..a465a898 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -54,7 +54,8 @@ class DecoderTests(unittest.TestCase): trans_agent=True, forward_attn_mask=True, location_attn=True, - separate_stopnet=True) + separate_stopnet=True, + speaker_embedding_dim=0) dummy_input = T.rand(4, 8, 256) dummy_memory = T.rand(4, 2, 80) @@ -66,6 +67,34 @@ class DecoderTests(unittest.TestCase): assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2]) assert stop_tokens.shape[0] == 4 + def test_in_out_multispeaker(self): + layer = Decoder( + in_features=256, + memory_dim=80, + r=2, + memory_size=4, + attn_windowing=False, + attn_norm="sigmoid", + prenet_type='original', + prenet_dropout=True, + forward_attn=True, + trans_agent=True, + forward_attn_mask=True, + location_attn=True, + separate_stopnet=True, + speaker_embedding_dim=80) + dummy_input = T.rand(4, 8, 256) + dummy_memory = T.rand(4, 2, 80) + dummy_embed = T.rand(4, 80) + + output, alignment, stop_tokens = layer( + dummy_input, dummy_memory, mask=None, speaker_embeddings=dummy_embed) + + assert output.shape[0] == 4 + assert output.shape[1] == 1, "size not {}".format(output.shape[1]) + assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2]) + assert stop_tokens.shape[0] == 4 + class EncoderTests(unittest.TestCase): def test_in_out(self):