Enable optional forward attention with transition agent

pull/10/head
Eren Golge 2019-04-10 16:41:30 +02:00
parent e2cf35bb10
commit 312a539a0e
4 changed files with 46 additions and 29 deletions

View File

@@ -41,7 +41,9 @@
"memory_size": 5, // ONLY TACOTRON - size of the memory queue used to store network predictions for the autoregressive connection. Useful if r < 5.
"attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "bn", // ONLY TACOTRON2 - "original" or "bn".
"use_forward_attn": false, // ONLY TACOTRON2 - enable/disable forward attention. In general, it makes alignment converge faster.
"use_forward_attn": true, // ONLY TACOTRON2 - enable/disable forward attention. In general, it makes alignment converge faster.
"transition_agent": true, // ONLY TACOTRON2 - enable/disable the transition agent of forward attention.
"loss_masking": false, // enable/disable loss masking against the sequence padding.
"batch_size": 16, // Batch size for training. Values lower than 32 might make attention hard to learn.
"eval_batch_size":16,

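For reference, these two flags switch on the forward-attention update that this commit adds to the attention layer (see apply_forward_attention below). Writing the normalized alignment at decoder step t as y_t(n), the accumulated weights alpha and the transition probability u are updated, up to the 1e-7 stabilizer and a renormalization, roughly as

\alpha_t(n) \propto \big( (1 - u_{t-1})\,\alpha_{t-1}(n) + u_{t-1}\,\alpha_{t-1}(n-1) \big)\, y_t(n), \qquad u_t = \sigma\big( W_{ta}\,[\,c_t ; q_t\,] + b_{ta} \big)

where c_t is the attention context, q_t the processed query, and W_ta, b_ta the weights of the ta linear layer. When transition_agent is false, the ta layer is never created and u keeps whatever value it was initialized with (the initialization is outside this diff).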
View File

@@ -42,6 +42,8 @@
"attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
"prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
"use_forward_attn": true, // ONLY TACOTRON2 - enable/disable forward attention. In general, it makes alignment converge faster.
"transition_agent": true, // ONLY TACOTRON2 - enable/disable the transition agent of forward attention.
"loss_masking": false, // enable/disable loss masking against the sequence padding.
"batch_size": 16, // Batch size for training. Values lower than 32 might make attention hard to learn.
"eval_batch_size":16,

View File

@@ -122,13 +122,15 @@ class LocationLayer(nn.Module):
class Attention(nn.Module):
def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
attention_location_n_filters, attention_location_kernel_size,
windowing, norm, forward_attn):
windowing, norm, forward_attn, trans_agent):
super(Attention, self).__init__()
self.query_layer = Linear(
attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
self.inputs_layer = Linear(
embedding_dim, attention_dim, bias=False, init_gain='tanh')
self.v = Linear(attention_dim, 1, bias=True)
if trans_agent:
self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
self.location_layer = LocationLayer(attention_location_n_filters,
attention_location_kernel_size,
attention_dim)
@@ -137,6 +139,7 @@ class Attention(nn.Module):
self.win_idx = None
self.norm = norm
self.forward_attn = forward_attn
self.trans_agent = trans_agent
def init_win_idx(self):
self.win_idx = -1
@@ -160,29 +163,46 @@ class Attention(nn.Module):
processed_inputs))
energies = energies.squeeze(-1)
return energies
return energies, processed_query
def apply_windowing(self, attention, inputs):
    back_win = self.win_idx - self.win_back
    front_win = self.win_idx + self.win_front
    if back_win > 0:
        attention[:, :back_win] = -float("inf")
    if front_win < inputs.shape[1]:
        attention[:, front_win:] = -float("inf")
    # trick for the very first decoder step: win_idx is -1 (set by
    # init_win_idx), so force index 0 to be the argmax and start the
    # window at the beginning of the input. It does not hurt otherwise.
    if self.win_idx == -1:
        attention[:, 0] = attention.max()
    # Update the window position
    self.win_idx = torch.argmax(attention, 1).long()[0].item()
    return attention
def apply_forward_attention(self, inputs, alignment, processed_query):
# forward attention
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
# compute context
context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
context = context.squeeze(1)
# compute transition agent
if self.trans_agent:
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
self.u = torch.sigmoid(self.ta(ta_input))
return context, alpha_norm, alignment
def forward(self, attention_hidden_state, inputs, processed_inputs,
attention_cat, mask):
attention = self.get_attention(
attention, processed_query = self.get_attention(
attention_hidden_state, processed_inputs, attention_cat)
if mask is not None:
attention.data.masked_fill_(1 - mask, self._mask_value)
# Windowing
if not self.training and self.windowing:
back_win = self.win_idx - self.win_back
front_win = self.win_idx + self.win_front
if back_win > 0:
attention[:, :back_win] = -float("inf")
if front_win < inputs.shape[1]:
attention[:, front_win:] = -float("inf")
# this is a trick to solve a special problem.
# but it does not hurt.
if self.win_idx == -1:
attention[:, 0] = attention.max()
# Update the window
self.win_idx = torch.argmax(attention, 1).long()[0].item()
attention = self.apply_windowing(attention, inputs)
if self.norm == "softmax":
alignment = torch.softmax(attention, dim=-1)
elif self.norm == "sigmoid":
@@ -191,14 +211,7 @@ class Attention(nn.Module):
else:
raise RuntimeError("Unknown value for attention norm type")
if self.forward_attn:
# forward attention
prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
# compute context
context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
context = context.squeeze(1)
return context, alpha_norm, alignment
return self.apply_forward_attention(inputs, alignment, processed_query)
else:
context = torch.bmm(alignment.unsqueeze(1), inputs)
context = context.squeeze(1)
@@ -272,7 +285,7 @@ class Encoder(nn.Module):
# adapted from https://github.com/NVIDIA/tacotron2/
class Decoder(nn.Module):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
super(Decoder, self).__init__()
self.mel_channels = inputs_dim
self.r = r
@@ -292,7 +305,7 @@ class Decoder(nn.Module):
self.attention_rnn_dim)
self.attention_layer = Attention(self.attention_rnn_dim, in_features,
128, 32, 31, attn_win, attn_norm, forward_attn)
128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
self.decoder_rnn_dim, 1)
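As a sanity check on the refactor above, here is a minimal, self-contained sketch of one forward-attention step with a transition agent, mirroring apply_forward_attention on toy tensors. The sizes, the initial alpha (all weight on the first input) and the initial u = 0.5 are assumptions for illustration; the real initialization lives outside this diff.

import torch
import torch.nn.functional as F

# Toy sizes; the commit itself passes attention_dim=128 and in_features=512.
B, T_enc, attn_dim, emb_dim = 2, 7, 128, 512

inputs = torch.rand(B, T_enc, emb_dim)              # encoder outputs
processed_query = torch.rand(B, 1, attn_dim)        # query after query_layer
alignment = torch.softmax(torch.rand(B, T_enc), -1) # output of the attention norm

# Assumed initial state (not shown in this diff): all mass on the first input.
alpha = torch.cat([torch.ones(B, 1), torch.zeros(B, T_enc - 1)], dim=1)
u = torch.full((B, 1), 0.5)                         # transition probability
ta = torch.nn.Linear(attn_dim + emb_dim, 1)         # transition agent layer

# One step, mirroring apply_forward_attention:
prev_alpha = F.pad(alpha[:, :-1], (1, 0))           # shift weights one input forward
alpha = (((1 - u) * alpha + u * prev_alpha) + 1e-7) * alignment
alpha_norm = alpha / alpha.sum(dim=1, keepdim=True)
context = torch.bmm(alpha_norm.unsqueeze(1), inputs).squeeze(1)

# Transition agent: how likely the alignment is to move forward at the next step.
u = torch.sigmoid(ta(torch.cat([context, processed_query.squeeze(1)], dim=-1)))

print(alpha_norm.shape, context.shape, u.shape)     # (B, T_enc), (B, emb_dim), (B, 1)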

View File

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
# TODO: match function arguments with tacotron
class Tacotron2(nn.Module):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
super(Tacotron2, self).__init__()
self.n_mel_channels = 80
self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
self.encoder = Encoder(512)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
self.postnet = Postnet(self.n_mel_channels)
def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
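For completeness, threading the new config keys from the snippets at the top into this constructor would presumably look like the sketch below; the config object c, num_chars, r and attn_win are placeholders, since the actual call site is not part of this commit.

# Hypothetical wiring of the new keys into the model; only attention_norm,
# prenet_type, use_forward_attn and transition_agent are taken from the
# config snippets above, the rest are placeholders.
model = Tacotron2(num_chars=num_chars,
                  r=r,
                  attn_win=attn_win,
                  attn_norm=c.attention_norm,
                  prenet_type=c.prenet_type,
                  forward_attn=c.use_forward_attn,
                  trans_agent=c.transition_agent)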