use sigmoid for attention

pull/10/head
Eren Golge 2019-01-16 16:26:05 +01:00
parent 7e020d4084
commit 4431e04b48
1 changed file with 4 additions and 4 deletions

@@ -170,9 +170,9 @@ class AttentionRNNCell(nn.Module):
         # Update the window
         self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
-        alignment = F.softmax(alignment, dim=-1)
+        # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
-        # alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
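
For context, the change swaps softmax normalization of the attention weights for a sigmoid squash followed by renormalization, so each row of weights still sums to 1. A minimal standalone sketch of the two schemes and of the context-vector step from the diff comment (the tensors `scores` and `h` are illustrative toys, not taken from the repo):

    import torch
    import torch.nn.functional as F

    # Toy attention scores over T_x = 5 encoder steps: (batch, T_x)
    scores = torch.tensor([[2.0, 1.0, 0.5, -1.0, 0.0]])

    # Old scheme: softmax. Weights sum to 1, and the exponential makes
    # the distribution comparatively peaky.
    softmax_weights = F.softmax(scores, dim=-1)

    # New scheme from this commit: squash each score with a sigmoid,
    # then renormalize so the row still sums to 1. Sigmoid saturates
    # instead of exponentiating, so the weights come out flatter.
    sig = torch.sigmoid(scores)
    sigmoid_weights = sig / sig.sum(dim=1).unsqueeze(1)

    assert torch.allclose(sigmoid_weights.sum(dim=1), torch.ones(1))

    # Context vector, as in the diff comment: c_i = sum_j alpha_ij h_j
    h = torch.randn(1, 5, 8)  # toy encoder outputs (batch, T_x, dim)
    context = torch.bmm(sigmoid_weights.unsqueeze(1), h)  # (batch, 1, dim)

Both normalizations produce valid convex weights over the encoder steps; the practical difference is only in how sharply the weights concentrate on the best-scoring step.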