diff --git a/models/tacotron.py b/models/tacotron.py index bd2a3ac7..8f40f313 100644 --- a/models/tacotron.py +++ b/models/tacotron.py @@ -44,7 +44,7 @@ class Tacotron(nn.Module): self.postnet = PostCBHG(mel_dim) self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim) - def __init_states(self): + def _init_states(self): self.speaker_embeddings = None self.speaker_embeddings_projected = None @@ -59,7 +59,7 @@ class Tacotron(nn.Module): B = characters.size(0) mask = sequence_mask(text_lengths).to(characters.device) inputs = self.embedding(characters) - self.__init_states() + self._init_states() self.compute_speaker_embedding(speaker_ids) if self.num_speakers > 1: inputs = self._concat_speaker_embedding(inputs, @@ -78,7 +78,7 @@ class Tacotron(nn.Module): def inference(self, characters, speaker_ids=None): B = characters.size(0) inputs = self.embedding(characters) - self.__init_states() + self._init_states() self.compute_speaker_embedding(speaker_ids) if self.num_speakers > 1: inputs = self._concat_speaker_embedding(inputs, @@ -98,10 +98,16 @@ class Tacotron(nn.Module): speaker_embeddings = self.speaker_embedding(speaker_ids) return speaker_embeddings.unsqueeze_(1) - def _concat_speaker_embedding(self, outputs, speaker_embeddings): + def _add_speaker_embedding(self, outputs, speaker_embeddings): speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), - outputs.size(1), - -1) - outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) + outputs.size(1), + -1) + outputs = outputs + speaker_embeddings_ return outputs + def _concat_speaker_embedding(self, outputs, speaker_embeddings): + speaker_embeddings_ = speaker_embeddings.expand(outputs.size(0), + outputs.size(1), + -1) + outputs = torch.cat([outputs, speaker_embeddings_], dim=-1) + return outputs diff --git a/models/tacotrongst.py b/models/tacotrongst.py index 5ea389d9..7d2fc626 100644 --- a/models/tacotrongst.py +++ b/models/tacotrongst.py @@ -1,11 +1,13 @@ # coding: utf-8 +import torch from torch import nn from TTS.layers.tacotron import Encoder, Decoder, PostCBHG from TTS.layers.gst_layers import GST from TTS.utils.generic_utils import sequence_mask +from TTS.models.tacotron import Tacotron -class TacotronGST(nn.Module): +class TacotronGST(Tacotron): def __init__(self, num_chars, num_speakers, @@ -22,37 +24,49 @@ class TacotronGST(nn.Module): forward_attn_mask=False, location_attn=True, separate_stopnet=True): - super(TacotronGST, self).__init__() - self.r = r - self.mel_dim = mel_dim - self.linear_dim = linear_dim - self.embedding = nn.Embedding(num_chars, 256) - self.embedding.weight.data.normal_(0, 0.3) - if num_speakers > 1: - self.speaker_embedding = nn.Embedding(num_speakers, 256) - self.speaker_embedding.weight.data.normal_(0, 0.3) - self.encoder = Encoder(256) - self.gst = GST(num_mel=80, num_heads=4, num_style_tokens=10, embedding_dim=256) - self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win, + super().__init__(num_chars, + num_speakers, + r, + linear_dim, + mel_dim, + memory_size, + attn_win, + attn_norm, + prenet_type, + prenet_dropout, + forward_attn, + trans_agent, + forward_attn_mask, + location_attn, + separate_stopnet) + gst_embedding_dim = 256 + decoder_dim = 512 + gst_embedding_dim if num_speakers > 1 else 256 + gst_embedding_dim + proj_speaker_dim = 80 if num_speakers > 1 else 0 + self.decoder = Decoder(decoder_dim, mel_dim, r, memory_size, attn_win, attn_norm, prenet_type, prenet_dropout, forward_attn, trans_agent, forward_attn_mask, - location_attn, separate_stopnet) - self.postnet = PostCBHG(mel_dim) - self.last_linear = nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim) - + location_attn, separate_stopnet, proj_speaker_dim) + self.gst = GST(num_mel=80, num_heads=4, + num_style_tokens=10, embedding_dim=gst_embedding_dim) def forward(self, characters, text_lengths, mel_specs, speaker_ids=None): B = characters.size(0) mask = sequence_mask(text_lengths).to(characters.device) inputs = self.embedding(characters) + self._init_states() + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + inputs = self._concat_speaker_embedding(inputs, + self.speaker_embeddings) encoder_outputs = self.encoder(inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + if self.num_speakers > 1: + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, + self.speaker_embeddings) gst_outputs = self.gst(mel_specs) - gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) - encoder_outputs = encoder_outputs + gst_outputs + encoder_outputs = self._concat_speaker_embedding( + encoder_outputs, gst_outputs) mel_outputs, alignments, stop_tokens = self.decoder( - encoder_outputs, mel_specs, mask) + encoder_outputs, mel_specs, mask, self.speaker_embeddings_projected) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) @@ -61,27 +75,23 @@ class TacotronGST(nn.Module): def inference(self, characters, speaker_ids=None, style_mel=None): B = characters.size(0) inputs = self.embedding(characters) + self._init_states() + self.compute_speaker_embedding(speaker_ids) + if self.num_speakers > 1: + inputs = self._concat_speaker_embedding(inputs, + self.speaker_embeddings) encoder_outputs = self.encoder(inputs) - encoder_outputs = self._add_speaker_embedding(encoder_outputs, - speaker_ids) + if self.num_speakers > 1: + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, + self.speaker_embeddings) if style_mel is not None: gst_outputs = self.gst(style_mel) gst_outputs = gst_outputs.expand(-1, encoder_outputs.size(1), -1) - encoder_outputs = encoder_outputs + gst_outputs + encoder_outputs = self._concat_speaker_embedding(encoder_outputs, + gst_outputs) mel_outputs, alignments, stop_tokens = self.decoder.inference( - encoder_outputs) + encoder_outputs, self.speaker_embeddings_projected) mel_outputs = mel_outputs.view(B, -1, self.mel_dim) linear_outputs = self.postnet(mel_outputs) linear_outputs = self.last_linear(linear_outputs) return mel_outputs, linear_outputs, alignments, stop_tokens - - def _add_speaker_embedding(self, encoder_outputs, speaker_ids): - if hasattr(self, "speaker_embedding") and speaker_ids is not None: - speaker_embeddings = self.speaker_embedding(speaker_ids) - - speaker_embeddings.unsqueeze_(1) - speaker_embeddings = speaker_embeddings.expand(encoder_outputs.size(0), - encoder_outputs.size(1), - -1) - encoder_outputs = encoder_outputs + speaker_embeddings - return encoder_outputs diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py index acd7af41..9b8de336 100644 --- a/tests/test_tacotron_model.py +++ b/tests/test_tacotron_model.py @@ -8,6 +8,7 @@ from torch import nn from TTS.utils.generic_utils import load_config from TTS.layers.losses import L1LossMasked from TTS.models.tacotron import Tacotron +from TTS.models.tacotrongst import TacotronGST #pylint: disable=unused-variable @@ -24,15 +25,74 @@ def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) -class TacotronTrainTest(unittest.TestCase): +# class TacotronTrainTest(unittest.TestCase): + # def test_train_step(self): + # input = torch.randint(0, 24, (8, 128)).long().to(device) + # input_lengths = torch.randint(100, 129, (8, )).long().to(device) + # input_lengths[-1] = 128 + # mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) + # linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) + # mel_lengths = torch.randint(20, 30, (8, )).long().to(device) + # stop_targets = torch.zeros(8, 30, 1).float().to(device) + # speaker_ids = torch.randint(0, 5, (8, )).long().to(device) + + # for idx in mel_lengths: + # stop_targets[:, int(idx.item()):, 0] = 1.0 + + # stop_targets = stop_targets.view(input.shape[0], + # stop_targets.size(1) // c.r, -1) + # stop_targets = (stop_targets.sum(2) > + # 0.0).unsqueeze(2).float().squeeze() + + # criterion = L1LossMasked().to(device) + # criterion_st = nn.BCEWithLogitsLoss().to(device) + # model = Tacotron( + # num_chars=32, + # num_speakers=5, + # linear_dim=c.audio['num_freq'], + # mel_dim=c.audio['num_mels'], + # r=c.r, + # memory_size=c.memory_size).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor + # model.train() + # print(" > Num parameters for Tacotron model:%s"%(count_parameters(model))) + # model_ref = copy.deepcopy(model) + # count = 0 + # for param, param_ref in zip(model.parameters(), + # model_ref.parameters()): + # assert (param - param_ref).sum() == 0, param + # count += 1 + # optimizer = optim.Adam(model.parameters(), lr=c.lr) + # for _ in range(5): + # mel_out, linear_out, align, stop_tokens = model.forward( + # input, input_lengths, mel_spec, speaker_ids) + # optimizer.zero_grad() + # loss = criterion(mel_out, mel_spec, mel_lengths) + # stop_loss = criterion_st(stop_tokens, stop_targets) + # loss = loss + criterion(linear_out, linear_spec, + # mel_lengths) + stop_loss + # loss.backward() + # optimizer.step() + # # check parameter changes + # count = 0 + # for param, param_ref in zip(model.parameters(), + # model_ref.parameters()): + # # ignore pre-higway layer since it works conditional + # # if count not in [145, 59]: + # assert (param != param_ref).any( + # ), "param {} with shape {} not updated!! \n{}\n{}".format( + # count, param.shape, param, param_ref) + # count += 1 + + +class TacotronGSTTrainTest(unittest.TestCase): def test_train_step(self): input = torch.randint(0, 24, (8, 128)).long().to(device) input_lengths = torch.randint(100, 129, (8, )).long().to(device) input_lengths[-1] = 128 - mel_spec = torch.rand(8, 30, c.audio['num_mels']).to(device) - linear_spec = torch.rand(8, 30, c.audio['num_freq']).to(device) - mel_lengths = torch.randint(20, 30, (8, )).long().to(device) - stop_targets = torch.zeros(8, 30, 1).float().to(device) + mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device) + linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device) + mel_lengths = torch.randint(20, 120, (8, )).long().to(device) + stop_targets = torch.zeros(8, 120, 1).float().to(device) speaker_ids = torch.randint(0, 5, (8, )).long().to(device) for idx in mel_lengths: @@ -45,7 +105,7 @@ class TacotronTrainTest(unittest.TestCase): criterion = L1LossMasked().to(device) criterion_st = nn.BCEWithLogitsLoss().to(device) - model = Tacotron( + model = TacotronGST( num_chars=32, num_speakers=5, linear_dim=c.audio['num_freq'], @@ -53,7 +113,8 @@ class TacotronTrainTest(unittest.TestCase): r=c.r, memory_size=c.memory_size).to(device) #FIXME: missing num_speakers parameter to Tacotron ctor model.train() - print(" > Num parameters for Tacotron model:%s"%(count_parameters(model))) + print(model) + print(" > Num parameters for Tacotron GST model:%s"%(count_parameters(model))) model_ref = copy.deepcopy(model) count = 0 for param, param_ref in zip(model.parameters(), @@ -61,7 +122,7 @@ class TacotronTrainTest(unittest.TestCase): assert (param - param_ref).sum() == 0, param count += 1 optimizer = optim.Adam(model.parameters(), lr=c.lr) - for _ in range(5): + for _ in range(10): mel_out, linear_out, align, stop_tokens = model.forward( input, input_lengths, mel_spec, speaker_ids) optimizer.zero_grad() @@ -76,7 +137,6 @@ class TacotronTrainTest(unittest.TestCase): for param, param_ref in zip(model.parameters(), model_ref.parameters()): # ignore pre-higway layer since it works conditional - # if count not in [145, 59]: assert (param != param_ref).any( ), "param {} with shape {} not updated!! \n{}\n{}".format( count, param.shape, param, param_ref)