mirror of https://github.com/coqui-ai/TTS.git

forward_attn_mask and config update
parent 127a6b68e0
commit 31fe02412c

@@ -37,7 +37,6 @@
"lr": 0.0001, // Initial learning rate. If Noam decay is active, maximum learning rate.
"lr_decay": false, // If true, Noam learning rate decay is applied during training.
"warmup_steps": 4000, // Noam decay steps to increase the learning rate from 0 to "lr"
"windowing": false, // Enables attention windowing. Used only in eval mode.
"memory_size": 5, // ONLY TACOTRON - size of the memory queue that holds past network predictions fed to the autoregressive connection. Useful if r < 5.
"attention_norm": "sigmoid", // softmax or sigmoid. softmax is suggested for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".

@@ -45,11 +44,14 @@
"use_forward_attn": true, // ONLY TACOTRON2 - whether to use forward attention. In general, it aligns faster.
"transition_agent": false, // ONLY TACOTRON2 - enable/disable the transition agent of forward attention.
"location_attn": false, // ONLY TACOTRON2 - enable/disable location-sensitive attention. It is enabled for TACOTRON by default.
"loss_masking": true, // enable/disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning-of-sentence and end-of-sentence chars.
"stopnet": true, // Train a stopnet that predicts the end of synthesis.
"separate_stopnet": true, // Train the stopnet separately if 'stopnet==true'. It prevents the stopnet loss from influencing the rest of the model. It yields a better model, but it trains SLOWER.
"tb_model_param_stats": false, // If true, plot per-layer parameter stats on TensorBoard. Might be memory consuming, but good for debugging.

"windowing": false, // Enables attention windowing. Used only in eval mode.
"forward_attn_masking": false, // Enable forward attention masking, which improves attention stability. Use it if the network does not behave as expected with it off.

"batch_size": 32, // Batch size for training. Values lower than 32 might make attention hard to learn.
"eval_batch_size": 16,

@@ -108,7 +108,7 @@ class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 location_attention, attention_location_n_filters,
                 attention_location_kernel_size, windowing, norm, forward_attn,
                 trans_agent):
                 trans_agent, forward_attn_mask):
        super(Attention, self).__init__()
        self.query_layer = Linear(
            attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')

@@ -128,6 +128,7 @@ class Attention(nn.Module):
        self.norm = norm
        self.forward_attn = forward_attn
        self.trans_agent = trans_agent
        self.forward_attn_mask = forward_attn_mask
        self.location_attention = location_attention

    def init_win_idx(self):

@@ -203,9 +204,10 @@ class Attention(nn.Module):
        alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                  self.u * prev_alpha) + 1e-8) * alignment
        # force incremental alignment - TODO: make configurable
        if not self.training:
        if not self.training and self.forward_attn_mask:
            _, n = prev_alpha.max(1)
            val, n2 = alpha.max(1)
            print(True)
            for b in range(alignment.shape[0]):
                alpha[b, n[b] + 2:] = 0
                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.

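For reference, the masking step above written as a standalone sketch. The function name is illustrative, alpha and prev_alpha are assumed to be [batch, encoder_steps] tensors, and the clamp at index 0 is an added guard for the very first step that the diff itself does not have:

# Standalone sketch of the forward-attention mask used at inference time.
# For each batch item only a small window around the previously attended
# encoder step survives: no jumps further than one step ahead, and no
# returns to states before the previous one (prevents repetition).
import torch

def apply_forward_attn_mask(alpha, prev_alpha):
    prev_idx = prev_alpha.argmax(dim=1)   # previously attended index per item
    for b in range(alpha.shape[0]):
        n = int(prev_idx[b])
        alpha[b, n + 2:] = 0              # forbid skipping ahead
        alpha[b, :max(n - 1, 0)] = 0      # forbid going back
    return alpha

In the diff this runs only when not self.training and self.forward_attn_mask is set, so training itself is unaffected.
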
@@ -271,7 +271,7 @@ class Decoder(nn.Module):

    def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
                 attn_norm, prenet_type, prenet_dropout, forward_attn,
                 trans_agent, location_attn, separate_stopnet):
                 trans_agent, forward_attn_mask, location_attn, separate_stopnet):
        super(Decoder, self).__init__()
        self.r = r
        self.in_features = in_features

@@ -296,7 +296,8 @@
            windowing=attn_windowing,
            norm=attn_norm,
            forward_attn=forward_attn,
            trans_agent=trans_agent)
            trans_agent=trans_agent,
            forward_attn_mask=forward_attn_mask)
        # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
        self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
        # decoder_RNN_input -> |RNN| -> RNN_state

@@ -97,7 +97,7 @@ class Encoder(nn.Module):
class Decoder(nn.Module):
    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm,
                 prenet_type, prenet_dropout, forward_attn, trans_agent,
                 location_attn, separate_stopnet):
                 forward_attn_mask, location_attn, separate_stopnet):
        super(Decoder, self).__init__()
        self.mel_channels = inputs_dim
        self.r = r

@@ -118,9 +118,17 @@ class Decoder(nn.Module):
        self.attention_rnn = nn.LSTMCell(self.prenet_dim + in_features,
                                         self.attention_rnn_dim)

        self.attention_layer = Attention(self.attention_rnn_dim, in_features,
                                         128, location_attn, 32, 31, attn_win,
                                         attn_norm, forward_attn, trans_agent)
        self.attention_layer = Attention(attention_rnn_dim=self.attention_rnn_dim,
                                         embedding_dim=in_features,
                                         attention_dim=128,
                                         location_attention=location_attn,
                                         attention_location_n_filters=32,
                                         attention_location_kernel_size=31,
                                         windowing=attn_win,
                                         norm=attn_norm,
                                         forward_attn=forward_attn,
                                         trans_agent=trans_agent,
                                         forward_attn_mask=forward_attn_mask)

        self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                       self.decoder_rnn_dim, 1)

@@ -19,6 +19,7 @@ class Tacotron(nn.Module):
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True):
        super(Tacotron, self).__init__()

@@ -30,8 +31,8 @@ class Tacotron(nn.Module):
        self.encoder = Encoder(256)
        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, location_attn,
                               separate_stopnet)
                               forward_attn, trans_agent, forward_attn_mask,
                               location_attn, separate_stopnet)
        self.postnet = PostCBHG(mel_dim)
        self.last_linear = nn.Sequential(
            nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),

@@ -18,6 +18,7 @@ class Tacotron2(nn.Module):
                 prenet_dropout=True,
                 forward_attn=False,
                 trans_agent=False,
                 forward_attn_mask=False,
                 location_attn=True,
                 separate_stopnet=True):
        super(Tacotron2, self).__init__()

@@ -30,8 +31,8 @@ class Tacotron2(nn.Module):
        self.encoder = Encoder(512)
        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win,
                               attn_norm, prenet_type, prenet_dropout,
                               forward_attn, trans_agent, location_attn,
                               separate_stopnet)
                               forward_attn, trans_agent, forward_attn_mask,
                               location_attn, separate_stopnet)
        self.postnet = Postnet(self.n_mel_channels)

    def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

@@ -260,6 +260,7 @@ def setup_model(num_chars, c):
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            separate_stopnet=c.separate_stopnet)
    elif c.model.lower() == "tacotron2":

@@ -272,6 +273,7 @@ def setup_model(num_chars, c):
            prenet_dropout=c.prenet_dropout,
            forward_attn=c.use_forward_attn,
            trans_agent=c.transition_agent,
            forward_attn_mask=c.forward_attn_mask,
            location_attn=c.location_attn,
            separate_stopnet=c.separate_stopnet)
    return model

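End-to-end, the flag travels from config.json ("forward_attn_masking") through setup_model (c.forward_attn_mask) into the Attention module of either model. A hedged usage sketch follows; the import paths and the character-set size are assumptions, since only setup_model's body is shown in this diff:

# Usage sketch; `c` must expose the config entries as attributes, including
# the new c.forward_attn_mask.
from utils.generic_utils import load_config, setup_model  # import paths assumed

c = load_config("config.json")
model = setup_model(num_chars=61, c=c)  # num_chars is illustrative

# Attention keeps the flag as a plain attribute, so it can also be toggled on an
# already-built Tacotron2 model before inference; the mask is only applied when
# the module is not in training mode.
model.decoder.attention_layer.forward_attn_mask = True
model.eval()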