Location sensitive attention fix

pull/10/head
Eren G 2018-07-17 17:43:51 +02:00
parent 4e6596a8e1
commit 4ef3ecf37f
2 changed files with 6 additions and 5 deletions


@@ -39,7 +39,7 @@ class LocationSensitiveAttention(nn.Module):
         self.kernel_size = kernel_size
         self.filters = filters
         padding = int((kernel_size - 1) / 2)
-        self.loc_conv = nn.Conv1d(1, filters,
+        self.loc_conv = nn.Conv1d(2, filters,
                                   kernel_size=kernel_size, stride=1,
                                   padding=padding, bias=False)
         self.loc_linear = nn.Linear(filters, attn_dim)
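
A minimal sketch of what the new in_channels=2 implies, assuming (this is not shown in the diff) that the location convolution now receives the previous attention weights and their running sum stacked along the channel dimension; all sizes below are illustrative, not the repository's values:

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch, max_time, filters, kernel_size = 4, 50, 32, 31
padding = int((kernel_size - 1) / 2)

# With in_channels=2 the convolution expects a (batch, 2, max_time) tensor,
# e.g. the previous attention weights and their cumulative sum as two channels.
loc_conv = nn.Conv1d(2, filters, kernel_size=kernel_size, stride=1,
                     padding=padding, bias=False)

prev_attn = torch.softmax(torch.randn(batch, max_time), dim=-1)  # alpha_{t-1}
cum_attn = torch.rand(batch, max_time)                           # assumed running sum of past alphas
loc_features = loc_conv(torch.stack([prev_attn, cum_attn], dim=1))
print(loc_features.shape)  # torch.Size([4, 32, 50]) -> (batch, filters, max_time)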
@@ -77,15 +77,15 @@ class AttentionRNNCell(nn.Module):
             out_dim (int): context vector feature dimension.
             rnn_dim (int): rnn hidden state dimension.
             annot_dim (int): annotation vector feature dimension.
-            memory_dim (int): memory vector (decoder autogression) feature dimension.
+            memory_dim (int): memory vector (decoder output) feature dimension.
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
-        self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
+        self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
         # pick bahdanau or location sensitive attention
         if align_model == 'b':
-            self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
+            self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
         if align_model == 'ls':
             self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
         else:
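
A small shape check of the corrected dimensions, not the repository's forward pass: the context vector is a weighted sum of the annotations and so carries annot_dim features, which is why the GRU cell's input size becomes annot_dim + memory_dim, and the rnn_dim-sized hidden state is what the alignment model receives as its query. All sizes here are assumptions for illustration:

import torch
import torch.nn as nn

# Illustrative sizes (assumed, not taken from the diff).
annot_dim, rnn_dim, memory_dim = 256, 256, 80
B = 4

rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)

context_vec = torch.zeros(B, annot_dim)    # previous context vector (annot_dim wide)
memory_input = torch.zeros(B, memory_dim)  # decoder output from the previous step
rnn_hidden = rnn_cell(torch.cat((context_vec, memory_input), dim=-1),
                      torch.zeros(B, rnn_dim))
print(rnn_hidden.shape)  # torch.Size([4, 256]) -> query fed to the alignment model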


@@ -198,6 +198,7 @@ class Decoder(nn.Module):
     def __init__(self, in_features, memory_dim, r):
         super(Decoder, self).__init__()
         self.r = r
+        self.in_features = in_features
         self.max_decoder_steps = 200
         self.memory_dim = memory_dim
         # memory -> |Prenet| -> processed_memory
@@ -249,7 +250,7 @@ class Decoder(nn.Module):
         attention_rnn_hidden = inputs.data.new(B, 256).zero_()
         decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
                                for _ in range(len(self.decoder_rnns))]
-        current_context_vec = inputs.data.new(B, 128).zero_()
+        current_context_vec = inputs.data.new(B, self.in_features).zero_()
         stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
         # attention states
         attention = inputs.data.new(B, T).zero_()
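
The same dimensional reasoning applies to the decoder's initial states: since the context vector is a weighted sum of the encoder annotations, it must be in_features wide rather than a hardcoded 128. A standalone sketch with assumed sizes:

import torch

# Assumed sizes for illustration only.
B, T, in_features, memory_dim, r = 4, 50, 256, 80, 5
inputs = torch.zeros(B, T, in_features)  # encoder outputs (annotations)

# Zero-initialised states, mirroring the diff: the context vector now tracks
# the annotation width instead of a fixed 128.
current_context_vec = inputs.data.new(B, in_features).zero_()
stopnet_rnn_hidden = inputs.data.new(B, r * memory_dim).zero_()
attention = inputs.data.new(B, T).zero_()
print(current_context_vec.shape)  # torch.Size([4, 256]) -> (B, in_features)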