use sigmoid for attention

pull/10/head
Eren Golge 2019-01-16 16:26:05 +01:00
parent 7e020d4084
commit 4431e04b48
1 changed file with 4 additions and 4 deletions


@@ -167,12 +167,12 @@ class AttentionRNNCell(nn.Module):
                 alignment[:, :back_win] = -float("inf")
             if front_win < memory.shape[1]:
                 alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment,1).long()[0].item()
         # Normalize context weight
-        alignment = F.softmax(alignment, dim=-1)
+        # alignment = F.softmax(alignment, dim=-1)
         # alignment = 5 * alignment
-        # alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
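
For context, the change swaps softmax normalization of the alignment scores for an element-wise sigmoid followed by an explicit sum-normalization, so each alignment row still sums to 1 but large scores saturate instead of dominating exponentially. Below is a minimal standalone sketch of the two normalizations; the toy tensor shapes are assumptions for illustration and are not part of the AttentionRNNCell code in this diff.

import torch
import torch.nn.functional as F

# Toy alignment scores: (batch, memory_length), as produced before normalization.
alignment = torch.randn(2, 5)

# Previous behaviour: softmax over the memory axis.
softmax_weights = F.softmax(alignment, dim=-1)

# Behaviour after this commit: squash each score independently with a sigmoid,
# then renormalize so every row sums to 1.
sig = torch.sigmoid(alignment)
sigmoid_weights = sig / sig.sum(dim=1).unsqueeze(1)

# Both yield valid attention weights (each row sums to 1) that can be used in
# the context vector c_i = sum_j alpha_ij * h_j above.
print(softmax_weights.sum(dim=1))   # tensor([1., 1.])
print(sigmoid_weights.sum(dim=1))   # tensor([1., 1.])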