From 2586be7d3347d2672af10852ff2fe59190da41ed Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Fri, 24 May 2019 13:40:56 +0200
Subject: [PATCH] batch-wise operation

---
 layers/tacotron2.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/layers/tacotron2.py b/layers/tacotron2.py
index b051634e..daea2bd8 100644
--- a/layers/tacotron2.py
+++ b/layers/tacotron2.py
@@ -222,7 +222,11 @@ class Attention(nn.Module):
         # force incremental alignment
         if not self.training:
             val, n = prev_alpha.max(1)
-            alignment[:, n+2 :] = 0
+            if alignment.shape[0] == 1:
+                alignment[:, n+2:] = 0
+            else:
+                for b in range(alignment.shape[0]):
+                    alignment[b, n[b]+2:] = 0
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device)
                   + self.u * prev_alpha) + 1e-8) * alignment
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
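
Note: as submitted, the else-branch read `alignment[b, n[b]+2:]` with no assignment, a no-op; the patch above adds the intended `= 0`. For reference, below is a minimal standalone sketch of the same batch-wise masking, assuming `alignment` and `prev_alpha` are [B, T] tensors as in the patch. The helper name `force_incremental_alignment` and the vectorized mask are illustrative, not part of the patch.

    import torch

    def force_incremental_alignment(alignment, prev_alpha):
        # Index of the most attended input position per batch item.
        _, n = prev_alpha.max(dim=1)                   # n: [B]
        t = torch.arange(alignment.shape[1], device=alignment.device)
        # mask[b, t] is True where t >= n[b] + 2, i.e. more than two
        # steps ahead of the previous attention peak.
        mask = t.unsqueeze(0) >= (n + 2).unsqueeze(1)  # [B, T]
        return alignment.masked_fill(mask, 0.0)

This produces the same result as the per-sample loop (`alignment[b, n[b]+2:] = 0` for every b) without iterating over the batch in Python, so it also covers the batch-size-1 branch.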