forward_attn_mask and config update

pull/10/head
Eren Golge 2019-06-04 00:39:29 +02:00
parent 127a6b68e0
commit 31fe02412c
7 changed files with 32 additions and 15 deletions

View File

@ -37,7 +37,6 @@
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // if true, Noam learning rate decaying is applied through training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
@ -45,11 +44,14 @@
"use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
"transition_agent": false, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"separate_stopnet": true, // Train stopnet seperately if 'stopnet==true'. It prevents stopnet loss to influence the rest of the model. It causes a better model, but it trains SLOWER.
"tb_model_param_stats": false, // true, plots param stats per layer on tensorboard. Might be memory consuming, but good for debugging.
"windowing": false, // Enables attention windowing. Used only in eval mode.
"forward_attn_masking": false, // Enable forward attention masking which improves attention stability. Use it if network does not work as you like when it is off.
"batch_size": 32, // Batch size for training. Lower values than 32 might cause hard to learn attention.
"eval_batch_size":16,

View File

@ -108,7 +108,7 @@ class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
attention_location_kernel_size, windowing, norm, forward_attn,
trans_agent):
trans_agent, forward_attn_mask):
super(Attention, self).__init__()
self.query_layer = Linear(
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@ -128,6 +128,7 @@ class Attention(nn.Module):
self.norm = norm
self.forward_attn = forward_attn
self.trans_agent = trans_agent
self.forward_attn_mask = forward_attn_mask
self.location_attention = location_attention
def init_win_idx(self):
@ -203,9 +204,10 @@ class Attention(nn.Module):
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
self.u * prev_alpha) + 1e-8) * alignment
# force incremental alignment, controlled by forward_attn_mask
if not self.training:
if not self.training and self.forward_attn_mask:
_, n = prev_alpha.max(1)
val, n2 = alpha.max(1)
for b in range(alignment.shape[0]):
alpha[b, n[b] + 2:] = 0
alpha[b, :(n[b] - 1)] = 0 # ignore all previous states to prevent repetition.
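
For clarity, here is the masking step above as a standalone sketch in plain PyTorch. Variable names mirror the diff; the function name, the `int` conversion, the lower-bound clamp, and the final renormalization are illustrative additions, not the repo's exact code:

```python
import torch

def forward_attn_mask(alpha, prev_alpha):
    """Zero attention weights far from the previous step's peak.

    alpha:      [B, T] forward-attention weights at the current decoder step
    prev_alpha: [B, T] weights from the previous decoder step
    """
    _, n = prev_alpha.max(1)              # index of the previous strongest alignment
    masked = alpha.clone()
    for b in range(alpha.shape[0]):
        peak = int(n[b])
        masked[b, peak + 2:] = 0          # block jumping too far forward
        masked[b, :max(peak - 1, 0)] = 0  # ignore earlier states to prevent repetition
    # keep the weights a valid distribution after masking
    return masked / masked.sum(dim=1, keepdim=True).clamp(min=1e-8)

# toy check
a = torch.softmax(torch.randn(2, 8), dim=1)
p = torch.softmax(torch.randn(2, 8), dim=1)
print(forward_attn_mask(a, p).sum(1))  # approx. tensor([1., 1.])
```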

View File

@ -271,7 +271,7 @@ class Decoder(nn.Module):
def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
attn_norm, prenet_type, prenet_dropout, forward_attn,
trans_agent, location_attn, separate_stopnet):
trans_agent, forward_attn_mask, location_attn, separate_stopnet):
super(Decoder, self).__init__()
self.r = r
self.in_features = in_features
@ -296,7 +296,8 @@ class Decoder(nn.Module):
windowing=attn_windowing,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent)
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
# (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
# decoder_RNN_input -> |RNN| -> RNN_state

View File

@ -97,7 +97,7 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
prenet_type, prenet_dropout, forward_attn, trans_agent,
location_attn, separate_stopnet):
forward_attn_mask, location_attn, separate_stopnet):
super(Decoder, self).__init__()
self.mel_channels = inputs_dim
self.r = r
@ -118,9 +118,17 @@ class Decoder(nn.Module):
self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
self.attention_rnn_dim)
self.attention_layer = Attention(self.attention_rnn_dim, in_features,
128, location_attn, 32, 31, attn_win,
attn_norm, forward_attn, trans_agent)
self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
embedding_dim=in_features,
attention_dim=128,
location_attention=location_attn,
attention_location_n_filters=32,
attention_location_kernel_size=31,
windowing=attn_win,
norm=attn_norm,
forward_attn=forward_attn,
trans_agent=trans_agent,
forward_attn_mask=forward_attn_mask)
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
self.decoder_rnn_dim, 1)

View File

@ -19,6 +19,7 @@ class Tacotron(nn.Module):
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True):
super(Tacotron, self).__init__()
@ -30,8 +31,8 @@ class Tacotron(nn.Module):
self.encoder = Encoder(256)
self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, location_attn,
separate_stopnet)
forward_attn, trans_agent, forward_attn_mask,
location_attn, separate_stopnet)
self.postnet = PostCBHG(mel_dim)
self.last_linear = nn.Sequential(
nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),

View File

@ -18,6 +18,7 @@ class Tacotron2(nn.Module):
prenet_dropout=True,
forward_attn=False,
trans_agent=False,
forward_attn_mask=False,
location_attn=True,
separate_stopnet=True):
super(Tacotron2, self).__init__()
@ -30,8 +31,8 @@ class Tacotron2(nn.Module):
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
attn_norm, prenet_type, prenet_dropout,
forward_attn, trans_agent, location_attn,
separate_stopnet)
forward_attn, trans_agent, forward_attn_mask,
location_attn, separate_stopnet)
self.postnet = Postnet(self.n_mel_channels)
def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

View File

@ -260,6 +260,7 @@ def setup_model(num_chars, c):
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
elif c.model.lower() == "tacotron2":
@ -272,6 +273,7 @@ def setup_model(num_chars, c):
prenet_dropout=c.prenet_dropout,
forward_attn=c.use_forward_attn,
trans_agent=c.transition_agent,
forward_attn_mask=c.forward_attn_mask,
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
return model
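
A hedged usage sketch for the updated setup_model() path. Only the attention-related fields added or read in this diff are shown; the import path and the remaining required config fields are assumptions, so the actual call is left commented out:

```python
from types import SimpleNamespace

# Attention-related config fields that setup_model() now forwards to the models.
c = SimpleNamespace(
    use_forward_attn=True,
    transition_agent=False,
    forward_attn_mask=False,   # new key wired through in this commit
    location_attn=False,
    # model name, r, prenet settings, etc. omitted here
)

# Hypothetical call, assuming the rest of the config is filled in and that
# the function lives in a utils/generic_utils.py module:
# from utils.generic_utils import setup_model
# model = setup_model(num_chars, c)
print(c.forward_attn_mask)
```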