From 3071e7f6f6b4336fa1c3a9744dacb771fbb0451a Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Mon, 19 Mar 2018 08:26:16 -0700
Subject: [PATCH] remove attention mask

---
 config.json        |  8 ++++----
 layers/tacotron.py | 12 ++----------
 models/tacotron.py | 13 ++++---------
 train.py           |  8 +++-----
 4 files changed, 13 insertions(+), 28 deletions(-)

diff --git a/config.json b/config.json
index 29b62a11..ffea4466 100644
--- a/config.json
+++ b/config.json
@@ -12,20 +12,20 @@
     "text_cleaner": "english_cleaners",
 
     "epochs": 2000,
-    "lr": 0.00001875,
+    "lr": 0.001,
     "warmup_steps": 4000,
-    "batch_size": 2,
+    "batch_size": 32,
     "eval_batch_size": 32,
     "r": 5,
 
     "griffin_lim_iters": 60,
     "power": 1.5,
 
-    "num_loader_workers": 16,
+    "num_loader_workers": 12,
 
     "checkpoint": false,
     "save_step": 69,
     "data_path": "/run/shm/erogol/LJSpeech-1.0",
-    "min_seq_len": 90,
+    "min_seq_len": 0,
     "output_path": "result"
 }
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 6f5926a8..38471214 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -231,8 +231,8 @@ class Decoder(nn.Module):
         # RNN_state -> |Linear| -> mel_spec
         self.proj_to_mel = nn.Linear(256, memory_dim * r)
 
-    def forward(self, inputs, memory=None, input_lengths=None):
-        r"""
+    def forward(self, inputs, memory=None):
+        """
         Decoder forward step.
 
         If decoder inputs are not given (e.g., at testing time), as noted in
@@ -242,8 +242,6 @@ class Decoder(nn.Module):
             inputs: Encoder outputs.
             memory (None): Decoder memory (autoregression. If None (at eval-time),
               decoder outputs are used as decoder inputs.
-            input_lengths (None): input lengths, used for
-              attention masking.
 
         Shapes:
             - inputs: batch x time x encoder_out_dim
@@ -251,11 +249,5 @@ class Decoder(nn.Module):
         """
         B = inputs.size(0)
-
-        # if input_lengths is not None:
-        #     mask = get_mask_from_lengths(processed_inputs, input_lengths)
-        # else:
-        #     mask = None
-
         # Run greedy decoding if memory is None
         greedy = memory is None
diff --git a/models/tacotron.py b/models/tacotron.py
index 57c9b43d..0b55b76b 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -8,12 +8,11 @@ from TTS.layers.tacotron import Prenet, Encoder, Decoder, CBHG
 
 class Tacotron(nn.Module):
     def __init__(self, embedding_dim=256, linear_dim=1025, mel_dim=80,
-                 freq_dim=1025, r=5, padding_idx=None,
-                 use_atten_mask=False):
+                 freq_dim=1025, r=5, padding_idx=None):
+
         super(Tacotron, self).__init__()
         self.mel_dim = mel_dim
         self.linear_dim = linear_dim
-        self.use_atten_mask = use_atten_mask
         self.embedding = nn.Embedding(len(symbols), embedding_dim,
                                       padding_idx=padding_idx)
         print(" | > Embedding dim : {}".format(len(symbols)))
@@ -26,16 +25,12 @@ class Tacotron(nn.Module):
         self.postnet = CBHG(mel_dim, K=8, projections=[256, mel_dim])
         self.last_linear = nn.Linear(mel_dim * 2, freq_dim)
 
-    def forward(self, characters, mel_specs=None, input_lengths=None):
+    def forward(self, characters, mel_specs=None):
         B = characters.size(0)
 
         inputs = self.embedding(characters)
         # (B, T', in_dim)
         encoder_outputs = self.encoder(inputs)
 
-        if not self.use_atten_mask:
-            input_lengths = None
-
         # (B, T', mel_dim*r)
-        mel_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specs, input_lengths=input_lengths)
+        mel_outputs, alignments = self.decoder(encoder_outputs, mel_specs)
diff --git a/train.py b/train.py
index 3b7ff638..7b32d74c 100644
--- a/train.py
+++ b/train.py
@@ -112,8 +112,7 @@ def train(model, criterion, data_loader, optimizer, epoch):
 
         # forward pass
         mel_output, linear_output, alignments =\
-            model.forward(text_input_var, mel_spec_var,
-                          input_lengths= torch.autograd.Variable(torch.cuda.LongTensor(sorted_lengths)))
+            model.forward(text_input_var, mel_spec_var)
 
         # loss computation
         mel_loss = criterion(mel_output, mel_spec_var)
@@ -337,9 +336,8 @@ def main(args):
                      c.hidden_size,
                      c.num_mels,
                      c.num_freq,
-                     c.r,
-                     use_atten_mask=True)
-
+                     c.r)
+
     optimizer = optim.Adam(model.parameters(),
                            lr=c.lr)
    if use_cuda:
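
Note: after this patch, Tacotron.forward(characters, mel_specs) no longer accepts input_lengths, so attention is computed over padded encoder steps as well. For reference, a minimal sketch of what the removed length-based masking could look like, assuming a helper shaped like the commented-out get_mask_from_lengths call above (names and shapes are illustrative, not this repo's exact API):

    import torch

    def get_mask_from_lengths(inputs, input_lengths):
        # inputs: (B, T, D) encoder outputs; input_lengths: (B,) valid lengths
        steps = torch.arange(inputs.size(1)).unsqueeze(0)   # (1, T)
        # True at padded positions (t >= length); attention energies would be
        # filled with -inf there before the softmax so alignments ignore padding
        return steps >= input_lengths.unsqueeze(1)          # (B, T)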