From 44c66c6e3ebc4ab8bca4fa09804ffc16cf209502 Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 5 Mar 2019 13:34:33 +0100
Subject: [PATCH] remove comments

---
 layers/attention.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/layers/attention.py b/layers/attention.py
index ee18386e..c59ce406 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -146,12 +146,7 @@ class AttentionRNNCell(nn.Module):
         if t == 0:
             self.alignment_model.reset()
             self.win_idx = 0
-        # Feed it to RNN
-        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
-        # Alignment
-        # (batch, max_time)
-        # e_{ij} = a(s_{i-1}, h_j)
         if self.align_model is 'b':
             alignment = self.alignment_model(annots, rnn_output)
         else:
@@ -169,13 +164,7 @@ class AttentionRNNCell(nn.Module):
                 alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment,1).long()[0].item()
-        # Normalize context weight
-        # alignment = F.softmax(alignment, dim=-1)
-        # alignment = 5 * alignment
         alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
-        # Attention context vector
-        # (batch, 1, dim)
-        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
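
For readers of this patch: the two hunks sit in the normalization and context steps of AttentionRNNCell. Below is a minimal standalone sketch of that math, not code from the repository; the variable names, toy sizes, and shapes are assumptions, and only the torch calls mirror the lines kept by the diff (sigmoid normalization in place of the commented-out softmax, then a batched matmul for the context vector c_i = \sum_j \alpha_{ij} h_j).

import torch

# Minimal sketch (assumed shapes): alignment scores (batch, max_time),
# annotations / encoder outputs (batch, max_time, dim).
batch, max_time, dim = 2, 5, 4
scores = torch.randn(batch, max_time)        # raw energies e_{ij}
annots = torch.randn(batch, max_time, dim)   # annotations h_j

# Sigmoid normalization kept by the patch (instead of softmax):
# alpha_{ij} = sigmoid(e_{ij}) / sum_j sigmoid(e_{ij})
sig = torch.sigmoid(scores)
alignment = sig / sig.sum(dim=1, keepdim=True)

# Context vector: c_i = sum_j alpha_{ij} h_j
# (batch, 1, max_time) x (batch, max_time, dim) -> (batch, 1, dim) -> (batch, dim)
context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)

print(alignment.sum(dim=1))  # each row sums to ~1.0
print(context.shape)         # torch.Size([2, 4])

The removed comments only documented this math (plus the earlier softmax experiment); the executable lines are unchanged by the patch.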