diff --git a/layers/tacotron.py b/layers/tacotron.py
index 0634e28a..6c5be799 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -103,8 +103,8 @@ class CBHG(nn.Module):
             num_highways (int): number of highways layers

         Shapes:
-            - input: batch x time x dim
-            - output: batch x time x dim*2
+            - input: B x D x T_in
+            - output: B x T_in x D*2
     """

     def __init__(self,
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index a1d76f7d..48607aea 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -100,7 +100,7 @@ class Decoder(nn.Module):
     #pylint: disable=attribute-defined-outside-init
     def __init__(self, in_features, memory_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 forward_attn_mask, location_attn, separate_stopnet, 
+                 forward_attn_mask, location_attn, separate_stopnet,
                  speaker_embedding_dim):
         super(Decoder, self).__init__()
         self.memory_dim = memory_dim
@@ -117,7 +117,7 @@ class Decoder(nn.Module):
         self.p_decoder_dropout = 0.1

         # memory -> |Prenet| -> processed_memory
-        prenet_dim = self.memory_dim + speaker_embedding_dim
+        prenet_dim = self.memory_dim
         self.prenet = Prenet(
             prenet_dim,
             prenet_type,
@@ -244,7 +244,10 @@ class Decoder(nn.Module):
         memory = self.get_go_frame(inputs).unsqueeze(0)
         memories = self._reshape_memory(memories)
         memories = torch.cat((memory, memories), dim=0)
-        memories = self.prenet(self._update_memory(memories))
+        memories = self._update_memory(memories)
+        if speaker_embeddings is not None:
+            memories = torch.cat([memories, speaker_embeddings], dim=-1)
+        memories = self.prenet(memories)

         self._init_states(inputs, mask=mask)
         self.attention.init_states(inputs)
@@ -252,8 +255,6 @@ class Decoder(nn.Module):
         outputs, stop_tokens, alignments = [], [], []
         while len(outputs) < memories.size(0) - 1:
             memory = memories[len(outputs)]
-            if speaker_embeddings is not None:
-                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, attention_weights, stop_token = self.decode(memory)
             outputs += [mel_output.squeeze(1)]
             stop_tokens += [stop_token.squeeze(1)]
@@ -277,7 +278,7 @@ class Decoder(nn.Module):
         while True:
             memory = self.prenet(memory)
             if speaker_embeddings is not None:
-                memory = torch.cat([memory, speaker_embeddings], dim=-1) 
+                memory = torch.cat([memory, speaker_embeddings], dim=-1)
             mel_output, alignment, stop_token = self.decode(memory)
             stop_token = torch.sigmoid(stop_token.data)
             outputs += [mel_output.squeeze(1)]
diff --git a/models/tacotron.py b/models/tacotron.py
index 52d5bf3a..d726ac03 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -96,7 +96,6 @@ class Tacotron(nn.Module):
             - speaker_ids: B x 1
         """
         self._init_states()
-        B = characters.size(0)
         mask = sequence_mask(text_lengths).to(characters.device)
         # B x T_in x embed_dim
         inputs = self.embedding(characters)
@@ -132,14 +131,13 @@ class Tacotron(nn.Module):
         return decoder_outputs, postnet_outputs, alignments, stop_tokens

     def inference(self, characters, speaker_ids=None, style_mel=None):
-        B = characters.size(0)
         inputs = self.embedding(characters)
         self._init_states()
         self.compute_speaker_embedding(speaker_ids)
         if self.num_speakers > 1:
             inputs = self._concat_speaker_embedding(inputs,
                                                     self.speaker_embeddings)
-        encoder_outputs = self.encoder(inputs) 
+        encoder_outputs = self.encoder(inputs)
         if self.gst and style_mel is not None:
             encoder_outputs = self.compute_gst(encoder_outputs, style_mel)
         if self.num_speakers > 1:
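
Note on the `_concat_speaker_embedding` call sites above: the helper itself is not part of this diff, but its effect is to tile one embedding per utterance across the time axis and concatenate it to the frame features. A minimal sketch, with assumed shapes:

```python
# Minimal sketch of what a helper like _concat_speaker_embedding does; the
# helper body is not shown in this diff, and the shapes here are assumptions.
import torch

def concat_speaker_embedding(outputs, speaker_embeddings):
    # outputs: B x T x D, speaker_embeddings: B x 1 x E -> B x T x (D + E)
    speaker_embeddings = speaker_embeddings.expand(-1, outputs.size(1), -1)
    return torch.cat([outputs, speaker_embeddings], dim=-1)

encoder_inputs = torch.rand(4, 8, 256)   # B x T_in x embed_dim
speaker = torch.rand(4, 1, 64)           # one embedding per utterance
print(concat_speaker_embedding(encoder_inputs, speaker).shape)  # [4, 8, 320]
```

The decoder change above follows the same idea: concatenating speaker embeddings to the whole memory sequence once, before the Prenet, replaces the per-step concatenation inside the decoding loop.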
diff --git a/models/tacotron2.py b/models/tacotron2.py
index d8a2e473..c885b8ed 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -28,8 +28,8 @@ class Tacotron2(nn.Module):
         self.decoder_output_dim = decoder_output_dim
         self.n_frames_per_step = r
         self.bidirectional_decoder = bidirectional_decoder
-        decoder_dim = 512 + 256 if num_speakers > 1 else 512
-        encoder_dim = 512 + 256 if num_speakers > 1 else 512
+        decoder_dim = 512 if num_speakers > 1 else 512
+        encoder_dim = 512 if num_speakers > 1 else 512
         proj_speaker_dim = 80 if num_speakers > 1 else 0
         # embedding layer
         self.embedding = nn.Embedding(num_chars, 512)
@@ -39,6 +39,8 @@ class Tacotron2(nn.Module):
         if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
+            self.speaker_embeddings = None
+            self.speaker_embeddings_projected = None
         self.encoder = Encoder(encoder_dim)
         self.decoder = Decoder(decoder_dim, self.decoder_output_dim, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
@@ -47,7 +49,7 @@ class Tacotron2(nn.Module):
         if self.bidirectional_decoder:
             self.decoder_backward = copy.deepcopy(self.decoder)
         self.postnet = Postnet(self.decoder_output_dim)
-
+
     def _init_states(self):
         self.speaker_embeddings = None
         self.speaker_embeddings_projected = None
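
The second hunk above also pre-declares `speaker_embeddings` and `speaker_embeddings_projected` so the attributes exist before `_init_states` resets them. For reference, a sketch of the embedding table it configures (the 512 width and 0.3 std are taken from the hunk; the speaker count is a placeholder):

```python
# Sketch of the speaker-embedding table used for multi-speaker training.
# 512 and the 0.3 std come from the hunk above; num_speakers is a placeholder.
import torch.nn as nn

num_speakers = 5
speaker_embedding = nn.Embedding(num_speakers, 512)
speaker_embedding.weight.data.normal_(0, 0.3)  # narrower spread than the default N(0, 1)
```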
diff --git a/tests/outputs/dummy_model_config.json b/tests/outputs/dummy_model_config.json
index 7ce69645..09845621 100644
--- a/tests/outputs/dummy_model_config.json
+++ b/tests/outputs/dummy_model_config.json
@@ -44,6 +44,7 @@
     "prenet_dropout": true,      // ONLY TACOTRON2 - enable/disable dropout at prenet.
     "use_forward_attn": true,    // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "forward_attn_mask": false,
+    "bidirectional_decoder": false,
     "transition_agent": false,   // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "location_attn": false,      // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
     "loss_masking": true,        // enable / disable loss masking against the sequence padding.
diff --git a/tests/test_layers.py b/tests/test_layers.py
index 6b5fd80b..f9e60363 100644
--- a/tests/test_layers.py
+++ b/tests/test_layers.py
@@ -29,7 +29,8 @@ class CBHGTests(unittest.TestCase):
             highway_features=80,
             gru_features=80,
             num_highways=4)
-        dummy_input = T.rand(4, 8, 128)
+        # B x D x T
+        dummy_input = T.rand(4, 128, 8)

         print(layer)
         output = layer(dummy_input)
@@ -63,8 +64,8 @@ class DecoderTests(unittest.TestCase):
             dummy_input, dummy_memory, mask=None)

         assert output.shape[0] == 4
-        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
-        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+        assert output.shape[1] == 80, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4

     @staticmethod
@@ -92,8 +93,8 @@ class DecoderTests(unittest.TestCase):
             dummy_input, dummy_memory, mask=None,
             speaker_embeddings=dummy_embed)
         assert output.shape[0] == 4
-        assert output.shape[1] == 1, "size not {}".format(output.shape[1])
-        assert output.shape[2] == 80 * 2, "size not {}".format(output.shape[2])
+        assert output.shape[1] == 80, "size not {}".format(output.shape[1])
+        assert output.shape[2] == 2, "size not {}".format(output.shape[2])
         assert stop_tokens.shape[0] == 4

diff --git a/tests/test_tacotron_model.py b/tests/test_tacotron_model.py
index c8b0d7ca..7e5e8daf 100644
--- a/tests/test_tacotron_model.py
+++ b/tests/test_tacotron_model.py
@@ -49,8 +49,8 @@ class TacotronTrainTest(unittest.TestCase):
         model = Tacotron(
             num_chars=32,
             num_speakers=5,
-            linear_dim=c.audio['num_freq'],
-            mel_dim=c.audio['num_mels'],
+            postnet_output_dim=c.audio['num_freq'],
+            decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
         ).to(device)  #FIXME: missing num_speakers parameter to Tacotron ctor
@@ -112,8 +112,8 @@ class TacotronGSTTrainTest(unittest.TestCase):
             num_chars=32,
             num_speakers=5,
             gst=True,
-            linear_dim=c.audio['num_freq'],
-            mel_dim=c.audio['num_mels'],
+            postnet_output_dim=c.audio['num_freq'],
+            decoder_output_dim=c.audio['num_mels'],
             r=c.r,
             memory_size=c.memory_size
         ).to(device)  #FIXME: missing num_speakers parameter to Tacotron ctor
diff --git a/train.py b/train.py
index 7f3e2ef4..45991015 100644
--- a/train.py
+++ b/train.py
@@ -80,8 +80,7 @@ def format_data(data):
     text_input = data[0]
     text_lengths = data[1]
     speaker_names = data[2]
-    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"
-                                          ] else None
+    linear_input = data[3] if c.model in ["Tacotron", "TacotronGST"] else None
     mel_input = data[4]
     mel_lengths = data[5]
     stop_targets = data[6]
@@ -98,7 +97,7 @@ def format_data(data):

     # set stop targets view, we predict a single stop token per r frames prediction
     stop_targets = stop_targets.view(text_input.shape[0],
-                                     stop_targets.size(1) // c.r, -1) 
+                                     stop_targets.size(1) // c.r, -1)
     stop_targets = (stop_targets.sum(2) >
                     0.0).unsqueeze(2).float().squeeze(2)
@@ -108,9 +107,7 @@ def format_data(data):
         text_lengths = text_lengths.cuda(non_blocking=True)
         mel_input = mel_input.cuda(non_blocking=True)
         mel_lengths = mel_lengths.cuda(non_blocking=True)
-        linear_input = linear_input.cuda(
-            non_blocking=True) if c.model in ["Tacotron", "TacotronGST"
-                                              ] else None
+        linear_input = linear_input.cuda(non_blocking=True) if c.model in ["Tacotron", "TacotronGST"] else None
         stop_targets = stop_targets.cuda(non_blocking=True)
         if speaker_ids is not None:
             speaker_ids = speaker_ids.cuda(non_blocking=True)
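
The stop-target reshape in `format_data` above is easier to follow with concrete numbers: frame-level stop flags are grouped into windows of `r`, and a window is positive if any frame in it is. A worked sketch with hypothetical sizes (the train.py diff continues below):

```python
# Worked sketch of the stop-target reduction in format_data: B x T x 1
# frame-level flags become one flag per group of r frames. Sizes are
# hypothetical.
import torch

B, T, r = 2, 6, 3
stop_targets = torch.zeros(B, T, 1)
stop_targets[:, -2:] = 1.0  # flag the trailing (padded/end) frames

stop_targets = stop_targets.view(B, stop_targets.size(1) // r, -1)  # B x T/r x r
stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2)
print(stop_targets)  # tensor([[0., 1.], [0., 1.]])
```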
@@ -352,8 +349,8 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch):
             start_time = time.time()
             # format data
-            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, avg_text_length, avg_spec_length = format_data(data)
-            assert mel_input.shape[1] % model.decoder.r == 0 
+            text_input, text_lengths, mel_input, mel_lengths, linear_input, stop_targets, speaker_ids, _, _ = format_data(data)
+            assert mel_input.shape[1] % model.decoder.r == 0

             # forward pass model
             if c.bidirectional_decoder:
@@ -622,7 +619,8 @@ def main(args): # pylint: disable=redefined-outer-name
             r, c.batch_size = gradual_training_scheduler(global_step, c)
             c.r = r
             model.decoder.set_r(r)
-            if c.bidirectional_decoder: model.decoder_backward.set_r(r)
+            if c.bidirectional_decoder:
+                model.decoder_backward.set_r(r)
             print(" > Number of outputs per iteration:", model.decoder.r)

         train_loss, global_step = train(model, criterion, criterion_st,
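
A closing note on `r`: it is the reduction factor (the decoder emits `r` mel frames per step, stored as `n_frames_per_step` in the model diff above), which is why `evaluate` asserts that the mel length divides by it, and why `set_r` must be applied to both decoders when the bidirectional decoder is enabled. A small sketch with hypothetical sizes:

```python
# Why mel_input.shape[1] % model.decoder.r must be 0: the decoder produces
# r mel frames per step, so the target length has to be a multiple of r.
# All sizes here are hypothetical.
import torch

r = 5
mel_input = torch.rand(4, 120, 80)  # B x T x num_mels, T padded to a multiple of r
assert mel_input.shape[1] % r == 0
print(mel_input.shape[1] // r)  # 24 decoder steps cover this target
```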