small refactoring

pull/10/head
Eren Golge 2019-07-16 21:15:24 +02:00
parent fd081c49b7
commit 0f0ec679ec
2 changed files with 20 additions and 10 deletions


@@ -84,15 +84,17 @@ class Prenet(nn.Module):
 class LocationLayer(nn.Module):
-    def __init__(self, attention_n_filters, attention_kernel_size,
-                 attention_dim):
+    def __init__(self,
+                 attention_dim,
+                 attention_n_filters=32,
+                 attention_kernel_size=31):
         super(LocationLayer, self).__init__()
         self.location_conv = nn.Conv1d(
             in_channels=2,
             out_channels=attention_n_filters,
-            kernel_size=31,
+            kernel_size=attention_kernel_size,
             stride=1,
-            padding=(31 - 1) // 2,
+            padding=(attention_kernel_size - 1) // 2,
             bias=False)
         self.location_dense = Linear(
             attention_n_filters, attention_dim, bias=False, init_gain='tanh')
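
The change replaces the hard-coded filter count and kernel size with keyword defaults of the same values, so behaviour is unchanged while both become configurable. A standalone sketch of the padding arithmetic (toy sizes, plain nn.Conv1d, not part of the diff; the in_channels=2 comment assumes the usual stacking of current and cumulative attention weights):

import torch
import torch.nn as nn

# With an odd kernel, padding=(kernel_size - 1) // 2 is "same" padding, so the
# convolution preserves the time axis and the location features stay aligned
# with the attention weights they were computed from.
attention_n_filters, attention_kernel_size = 32, 31
location_conv = nn.Conv1d(
    in_channels=2,  # e.g. current attention weights and their cumulative sum
    out_channels=attention_n_filters,
    kernel_size=attention_kernel_size,
    stride=1,
    padding=(attention_kernel_size - 1) // 2,
    bias=False)

attention_cat = torch.randn(4, 2, 150)     # (batch, 2, max_time)
print(location_conv(attention_cat).shape)  # torch.Size([4, 32, 150]), time length preserved
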
@@ -120,8 +122,10 @@ class Attention(nn.Module):
                 attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
-                attention_location_n_filters, attention_location_kernel_size,
-                attention_dim)
+                attention_dim,
+                attention_location_n_filters,
+                attention_location_kernel_size,
+            )
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
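
The call site is reordered to match the new signature; keeping the old positional order would silently feed the filter count into attention_dim. A toy sketch (hypothetical stand-in function, not the project's class) of how the new defaults and keyword arguments can be used:

# Stand-in mirroring the new parameter order; only illustrates call styles.
def make_location_layer(attention_dim, attention_n_filters=32, attention_kernel_size=31):
    return (attention_dim, attention_n_filters, attention_kernel_size)

print(make_location_layer(128))                            # rely on the new defaults: (128, 32, 31)
print(make_location_layer(128, attention_kernel_size=15))  # override one default by name: (128, 32, 15)
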
@@ -203,14 +207,18 @@ class Attention(nn.Module):
         # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
-        # force incremental alignment - TODO: make configurable
+        # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
-                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
+                alpha[b, :(
+                    n[b] - 1
+                )] = 0  # ignore all previous states to prevent repetition.
+                alpha[b,
+                      (n[b] - 2
+                       )] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
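
Only the line wrapping changes here (plus dropping the TODO note); the masking logic itself is untouched. A standalone sketch of what that masking does at inference time, with toy numbers and the same indexing as above (not part of the diff):

import torch

# For each batch item, n is the most-attended index of the previous step. Zero
# every position more than two steps ahead of n and every position before n - 1,
# then put back a small smoothed value at n - 2, so attention can only move
# forward or ease back slightly and cannot jump or repeat earlier states.
torch.manual_seed(0)
alpha = torch.rand(1, 10)        # (batch, max_time) toy transition potentials
prev_alpha = torch.zeros(1, 10)
prev_alpha[0, 4] = 1.0           # previous step attended index 4

_, n = prev_alpha.max(1)
val, _ = alpha.max(1)
for b in range(alpha.shape[0]):
    alpha[b, n[b] + 3:] = 0               # nothing beyond n + 2
    alpha[b, :(n[b] - 1)] = 0             # ignore earlier states to prevent repetition
    alpha[b, (n[b] - 2)] = 0.01 * val[b]  # small smoothing value two steps back
print(alpha)  # only indices 2..6 can be non-zero, index 2 holding the smoothed term
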
@@ -240,7 +248,8 @@ class Attention(nn.Module):
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
             alignment = torch.sigmoid(attention) / torch.sigmoid(
-                attention).sum(dim=1).unsqueeze(1)
+                attention).sum(
+                    dim=1, keepdim=True)
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.location_attention:
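
The sigmoid branch is only reformatted; sum(dim=1, keepdim=True) yields the same (batch, 1) denominator as sum(dim=1).unsqueeze(1). A quick equivalence check (toy tensor, not part of the diff):

import torch

attention = torch.randn(4, 7)  # (batch, max_time) toy attention scores
old = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1).unsqueeze(1)
new = torch.sigmoid(attention) / torch.sigmoid(attention).sum(dim=1, keepdim=True)
print(torch.allclose(old, new))  # True
print(new.sum(dim=1))            # each row sums to 1
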


@@ -268,6 +268,7 @@ class Decoder(nn.Module):
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
         memory_size (int): size of the past window. if <= 0 memory_size = r
+        TODO: arguments
     """
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,