From 31fe02412cd5dbedb4dd6a8ef435d38b2f27f466 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 4 Jun 2019 00:39:29 +0200
Subject: [PATCH] forward_attn_mask and config update

---
 config_cluster.json     |  8 +++++---
 layers/common_layers.py |  5 +++--
 layers/tacotron.py      |  5 +++--
 layers/tacotron2.py     | 16 ++++++++++++----
 models/tacotron.py      |  5 +++--
 models/tacotron2.py     |  5 +++--
 utils/generic_utils.py  |  2 ++
 7 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/config_cluster.json b/config_cluster.json
index 2c05ca41..420b03ed 100644
--- a/config_cluster.json
+++ b/config_cluster.json
@@ -37,7 +37,6 @@
     "lr": 0.0001,                // Initial learning rate. If Noam decay is active, maximum learning rate.
     "lr_decay": false,           // if true, Noam learning rate decaying is applied through training.
     "warmup_steps": 4000,        // Noam decay steps to increase the learning rate from 0 to "lr"
-    "windowing": false,          // Enables attention windowing. Used only in eval mode.
     "memory_size": 5,            // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original",   // ONLY TACOTRON2 - "original" or "bn".
@@ -45,11 +44,14 @@
     "use_forward_attn": true,    // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
     "transition_agent": false,   // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "location_attn": false,      // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
-    "loss_masking": true,        // enable / disable loss masking against the sequence padding.
+    "loss_masking": true,        // enable / disable loss masking against the sequence padding.
     "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
     "stopnet": true,             // Train stopnet predicting the end of synthesis.
-    "separate_stopnet": true,    // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
+    "separate_stopnet": true,    // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
     "tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
+
+    "windowing": false,          // Enables attention windowing. Used only in eval mode.
+    "forward_attn_mask": false,  // Enable forward attention masking, which improves attention stability. Enable it if the model does not align well without it.
     "batch_size": 32,            // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
diff --git a/layers/common_layers.py b/layers/common_layers.py
index 3f694463..061134e2 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -108,7 +108,7 @@ class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  location_attention, attention_location_n_filters,
                  attention_location_kernel_size, windowing, norm, forward_attn,
-                 trans_agent):
+                 trans_agent, forward_attn_mask):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -128,6 +128,7 @@ class Attention(nn.Module):
         self.norm = norm
         self.forward_attn = forward_attn
         self.trans_agent = trans_agent
+        self.forward_attn_mask = forward_attn_mask
         self.location_attention = location_attention

     def init_win_idx(self):
@@ -203,9 +204,9 @@ class Attention(nn.Module):
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device)
                   + self.u * prev_alpha) + 1e-8) * alignment
         # force incremental alignment - TODO: make configurable
-        if not self.training:
+        if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 2:] = 0
                 alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 8d762cc8..034f682b 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -271,7 +271,7 @@ class Decoder(nn.Module):
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
                  attn_norm, prenet_type, prenet_dropout, forward_attn,
-                 trans_agent, location_attn, separate_stopnet):
+                 trans_agent, forward_attn_mask, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -296,7 +296,8 @@ class Decoder(nn.Module):
             windowing=attn_windowing,
             norm=attn_norm,
             forward_attn=forward_attn,
-            trans_agent=trans_agent)
+            trans_agent=trans_agent,
+            forward_attn_mask=forward_attn_mask)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state
diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index 0277734c..09bf5373 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -97,7 +97,7 @@ class Encoder(nn.Module):
 class Decoder(nn.Module):
     def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
                  prenet_type, prenet_dropout, forward_attn, trans_agent,
-                 location_attn, separate_stopnet):
+                 forward_attn_mask, location_attn, separate_stopnet):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -118,9 +118,17 @@ class Decoder(nn.Module):
         self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                          self.attention_rnn_dim)

-        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, location_attn, 32, 31, attn_win,
-                                         attn_norm, forward_attn, trans_agent)
+        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
+                                         embedding_dim=in_features,
+                                         attention_dim=128,
+                                         location_attention=location_attn,
+                                         attention_location_n_filters=32,
+                                         attention_location_kernel_size=31,
+                                         windowing=attn_win,
+                                         norm=attn_norm,
+                                         forward_attn=forward_attn,
+                                         trans_agent=trans_agent,
+                                         forward_attn_mask=forward_attn_mask)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
diff --git a/models/tacotron.py b/models/tacotron.py
index 362bf8b5..5d2af992 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -19,6 +19,7 @@ class Tacotron(nn.Module):
                  prenet_dropout=True,
                  forward_attn=False,
                  trans_agent=False,
+                 forward_attn_mask=False,
                  location_attn=True,
                  separate_stopnet=True):
         super(Tacotron, self).__init__()
@@ -30,8 +31,8 @@ class Tacotron(nn.Module):
         self.encoder = Encoder(256)
         self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
-                               forward_attn, trans_agent, location_attn,
-                               separate_stopnet)
+                               forward_attn, trans_agent, forward_attn_mask,
+                               location_attn, separate_stopnet)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),
diff --git a/models/tacotron2.py b/models/tacotron2.py
index da4894fc..c306a174 100644
--- a/models/tacotron2.py
+++ b/models/tacotron2.py
@@ -18,6 +18,7 @@ class Tacotron2(nn.Module):
                  prenet_dropout=True,
                  forward_attn=False,
                  trans_agent=False,
+                 forward_attn_mask=False,
                  location_attn=True,
                  separate_stopnet=True):
         super(Tacotron2, self).__init__()
@@ -30,8 +31,8 @@ class Tacotron2(nn.Module):
         self.encoder = Encoder(512)
         self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
                                attn_norm, prenet_type, prenet_dropout,
-                               forward_attn, trans_agent, location_attn,
-                               separate_stopnet)
+                               forward_attn, trans_agent, forward_attn_mask,
+                               location_attn, separate_stopnet)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index 77fb4dc2..60519871 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -260,6 +260,7 @@ def setup_model(num_chars, c):
             prenet_dropout=c.prenet_dropout,
             forward_attn=c.use_forward_attn,
             trans_agent=c.transition_agent,
+            forward_attn_mask=c.forward_attn_mask,
             location_attn=c.location_attn,
             separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
@@ -272,6 +273,7 @@ def setup_model(num_chars, c):
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
+           forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            separate_stopnet=c.separate_stopnet)
     return model
\ No newline at end of file
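
A note on the new flag: forward_attn_mask only takes effect at inference, since the gate in Attention is `not self.training and self.forward_attn_mask`. When it is on, the forward-attention weights are masked so that, relative to the encoder step attended in the previous decoder frame, attention can move at most one step forward or backward. Below is a minimal, self-contained Python sketch of that masking; the helper name `forward_attn_mask_sketch` and the random inputs are made up for illustration and mirror only the lines visible in the common_layers.py hunk above.

    import torch

    def forward_attn_mask_sketch(alpha, prev_alpha):
        # alpha, prev_alpha: (batch, encoder_steps) forward-attention weights of the
        # current and previous decoder step. Mirrors the masking gated by
        # forward_attn_mask in Attention.apply_forward_attention.
        _, n = prev_alpha.max(1)           # encoder index attended at the previous step
        for b in range(alpha.shape[0]):
            alpha[b, n[b] + 2:] = 0        # block jumps of more than one encoder step ahead
            alpha[b, :(n[b] - 1)] = 0      # ignore all previous states to prevent repetition
        return alpha

    # Example: mask random attention weights for a batch of 2 and 50 encoder steps.
    alpha = torch.softmax(torch.rand(2, 50), dim=1)
    prev_alpha = torch.softmax(torch.rand(2, 50), dim=1)
    masked = forward_attn_mask_sketch(alpha, prev_alpha)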