Check the last two attention values for the stop condition, and attend to the first encoder vector on the first decoder iteration

pull/10/head
Eren Golge 2019-03-06 23:46:02 +01:00
parent d8add67a2f
commit 4b116a2a88
1 changed file with 6 additions and 2 deletions


@@ -105,7 +105,7 @@ class Attention(nn.Module):
         self.win_idx = None
 
     def init_win_idx(self):
-        self.win_idx = 0
+        self.win_idx = -1
 
     def get_attention(self, query, processed_inputs, attention_cat):
         processed_query = self.query_layer(query.unsqueeze(1))
@@ -132,6 +132,10 @@ class Attention(nn.Module):
                 attention[:, :back_win] = -float("inf")
             if front_win < inputs.shape[1]:
                 attention[:, front_win:] = -float("inf")
+            # this is a trick to solve a special problem.
+            # but it does not hurt.
+            if self.win_idx == -1:
+                attention[:, 0] = attention.max()
             # Update the window
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
         alignment = torch.sigmoid(attention) / torch.sigmoid(
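
Taken together with the init_win_idx change above, the new branch handles the very first decoder step: win_idx is still -1, so the score of the first encoder vector is lifted to the current maximum, the argmax lands on index 0, and the attention window starts at the beginning of the input. Below is a minimal standalone sketch of that masking logic; the function name and window sizes are illustrative, not the repository's defaults.

import torch

def window_scores(attention, win_idx, win_back=3, win_front=6):
    # Mask raw scores to a window around the previously attended index.
    back_win = win_idx - win_back
    front_win = win_idx + win_front
    if back_win > 0:
        attention[:, :back_win] = -float("inf")
    if front_win < attention.shape[1]:
        attention[:, front_win:] = -float("inf")
    # First decoder step: win_idx is still -1, so lift the first encoder
    # vector to the current maximum; argmax then starts the window at 0.
    if win_idx == -1:
        attention[:, 0] = attention.max()
    return attention, torch.argmax(attention, 1)[0].item()

scores = torch.randn(1, 20)
_, idx = window_scores(scores, win_idx=-1)
print(idx)  # 0 on the first step (ties resolve to the earliest index)
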
@@ -355,7 +359,7 @@ class Decoder(nn.Module):
             alignments += [alignment]
 
             stop_flags[0] = stop_flags[0] or gate_output > 0.5
-            stop_flags[1] = stop_flags[1] or alignment[0, -3:].sum() > 0.5
+            stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
             if all(stop_flags):
                 break
             elif len(outputs) == self.max_decoder_steps:
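
In the decoder, generation stops only when both flags are set: the stop gate has fired and the attention mass has reached the end of the encoder sequence. This commit narrows that second check from the last three to the last two alignment values. A condensed, hypothetical version of the check (names, values, and shapes are made up for illustration):

import torch

def should_stop(gate_output, alignment, stop_flags):
    # Flag 0: the stop gate fired at least once.
    stop_flags[0] = stop_flags[0] or gate_output > 0.5
    # Flag 1: attention mass has reached the last two encoder steps.
    stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
    return all(stop_flags), stop_flags

alignment = torch.zeros(1, 20)
alignment[0, -1] = 0.9          # attention sits on the final encoder step
stop, _ = should_stop(0.8, alignment, [False, False])
print(stop)  # True
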