Cache attention annot vectors for the whole sequence.

pull/10/head
Eren Golge 2018-12-11 16:06:02 +01:00
parent 211a20a47a
commit dc3d09304e
2 changed files with 14 additions and 5 deletions

View File

@ -56,6 +56,7 @@ class LocationSensitiveAttention(nn.Module):
self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
self.v = nn.Linear(attn_dim, 1, bias=False)
self.processed_annots = None
# self.init_layers()
def init_layers(self):
@ -72,6 +73,9 @@ class LocationSensitiveAttention(nn.Module):
self.v.weight,
gain=torch.nn.init.calculate_gain('linear'))
def reset(self):
self.processed_annots = None
def forward(self, annot, query, loc):
"""
Shapes:
@ -86,9 +90,11 @@ class LocationSensitiveAttention(nn.Module):
loc_conv = loc_conv.transpose(1, 2)
processed_loc = self.loc_linear(loc_conv)
processed_query = self.query_layer(query)
processed_annots = self.annot_layer(annot)
# cache annots
if self.processed_annots is None:
self.processed_annots = self.annot_layer(annot)
alignment = self.v(
torch.tanh(processed_query + processed_annots + processed_loc))
torch.tanh(processed_query + self.processed_annots + processed_loc))
# (batch, max_time)
return alignment.squeeze(-1)
@ -120,7 +126,7 @@ class AttentionRNNCell(nn.Module):
'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
align_model))
def forward(self, memory, context, rnn_state, annots, atten, mask):
def forward(self, memory, context, rnn_state, annots, atten, mask, t):
"""
Shapes:
- memory: (batch, 1, dim) or (batch, dim)
@ -130,6 +136,8 @@ class AttentionRNNCell(nn.Module):
- atten: (batch, 2, max_time)
- mask: (batch,)
"""
if t == 0:
self.alignment_model.reset()
# Concat input query and previous context context
rnn_input = torch.cat((memory, context), -1)
# Feed it to RNN
@ -147,7 +155,8 @@ class AttentionRNNCell(nn.Module):
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight
# alignment = F.softmax(alignment, dim=-1)
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# alignment = 5 * alignment
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j

View File

@ -402,7 +402,7 @@ class Decoder(nn.Module):
(attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
processed_memory, current_context_vec, attention_rnn_hidden,
inputs, attention_cat, mask)
inputs, attention_cat, mask, t)
attention_cum += attention
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(