Merge branch 'attention-smoothing' into attn-smoothing-bgs-sigmoid-wd

pull/10/head
Eren 2018-09-26 16:51:42 +02:00
commit 006354320e
1 changed file with 2 additions and 1 deletion

View File

@@ -131,7 +131,8 @@ class AttentionRNNCell(nn.Module):
mask = mask.view(memory.size(0), -1)
alignment.masked_fill_(1 - mask, -float("inf"))
# Normalize context weight
alignment = F.softmax(alignment, dim=-1)
# alignment = F.softmax(alignment, dim=-1)
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j