Add a missing class variable to attention class

pull/10/head
Eren Golge 2018-05-23 06:20:04 -07:00
parent 8ffc85008a
commit fe99baec5a
1 changed files with 1 additions and 0 deletions

View File

@ -80,6 +80,7 @@ class AttentionRNN(nn.Module):
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
"""
super(AttentionRNN, self).__init__()
self.align_model = align_model
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, out_dim)
# pick bahdanau or location sensitive attention
if align_model == 'b':