diff --git a/layers/attention.py b/layers/attention.py index 6b9ee47b..f9a36e66 100644 --- a/layers/attention.py +++ b/layers/attention.py @@ -80,6 +80,7 @@ class AttentionRNN(nn.Module): align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment. """ super(AttentionRNN, self).__init__() + self.align_model = align_model self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim) # pick bahdanau or location sensitive attention if align_model == 'b':