use sigmoid for attention

pull/10/head
Eren Golge 2019-01-16 16:26:05 +01:00
parent 7e020d4084
commit 4431e04b48
1 changed files with 4 additions and 4 deletions

View File

@ -167,12 +167,12 @@ class AttentionRNNCell(nn.Module):
alignment[:, :back_win] = -float("inf")
if front_win < memory.shape[1]:
alignment[:, front_win:] = -float("inf")
# Update the window
self.win_idx = torch.argmax(alignment,1).long()[0].item()
# Update the window
self.win_idx = torch.argmax(alignment,1).long()[0].item()
# Normalize context weight
alignment = F.softmax(alignment, dim=-1)
# alignment = F.softmax(alignment, dim=-1)
# alignment = 5 * alignment
# alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
# Attention context vector
# (batch, 1, dim)
# c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j