diff --git a/layers/common_layers.py b/layers/common_layers.py
index d939fbc9..3f694463 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -207,9 +207,9 @@ class Attention(nn.Module):
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
-                alpha[b, n + 2:] = 0
-                alpha[b, :(n - 1)] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (n - 2)] = 0.01 * val  # smoothing factor for the prev step
+                alpha[b, n[b] + 2:] = 0
+                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
+                alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
diff --git a/requirements.txt b/requirements.txt
index 2e145a8f..6a7a446f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,6 +8,5 @@ matplotlib==2.0.2
 Pillow
 flask
 scipy==0.19.0
-lws
 tqdm
 git+git://github.com/bootphon/phonemizer@master
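
A note on the first hunk (illustrative, not part of the patch): `prev_alpha.max(1)` and `alpha.max(1)` return one result per batch item, so `n` and `val` are `[B]`-shaped tensors. Using them directly as slice bounds only ever worked for batch size 1; inside the per-sample loop, the fix indexes the scalar for the current sample with `n[b]` / `val[b]`. A minimal, self-contained sketch of the difference, using made-up shapes and random data in place of the real forward-attention state:

```python
import torch

# Illustrative values only; not taken from the model config.
B, T = 4, 10                       # batch size, encoder timesteps
alpha = torch.rand(B, T)           # stand-in for the transition potentials
prev_alpha = torch.rand(B, T)      # stand-in for the previous alignment

_, n = prev_alpha.max(1)           # shape [B]: one argmax index per sample
val, _ = alpha.max(1)              # shape [B]: one max value per sample

# Old code: `n + 2` is a [B]-element tensor, and a multi-element tensor
# cannot be converted to a slice bound, so this raises an error for B > 1:
#   alpha[0, n + 2:] = 0
# Similarly, `alpha[0, (n - 2)] = 0.01 * val` would silently write all B
# values into row 0 via fancy indexing instead of one value per row.

# Fixed code: pick out the per-sample scalar before slicing.
for b in range(B):
    alpha[b, n[b] + 2:] = 0               # zero everything past the window
    alpha[b, :(n[b] - 1)] = 0             # zero already-attended positions
    alpha[b, (n[b] - 2)] = 0.01 * val[b]  # small smoothing term at prev step
```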