batch-wise operation

pull/10/head
Eren Golge 2019-05-24 13:40:56 +02:00
parent 3a4a3e571a
commit 2586be7d33
1 changed files with 5 additions and 1 deletions

View File

@ -222,7 +222,11 @@ class Attention(nn.Module):
# force incremental alignment
if not self.training:
val, n = prev_alpha.max(1)
if alignment.shape[0] == 1:
alignment[:, n+2:] = 0
else:
for b in range(alignment.shape[0]):
alignment[b, n[b]+2:]
alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
self.u * prev_alpha) + 1e-8) * alignment
self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)