From 0f0ec679ec0f7c24579f6f8c063dd00f8d4c8b8d Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 16 Jul 2019 21:15:24 +0200
Subject: [PATCH] small refactoring

---
 layers/common_layers.py | 29 +++++++++++++++++++----------
 layers/tacotron.py      |  1 +
 2 files changed, 20 insertions(+), 10 deletions(-)

diff --git a/layers/common_layers.py b/layers/common_layers.py
index f7b8e7ed..c84b04b9 100644
--- a/layers/common_layers.py
+++ b/layers/common_layers.py
@@ -84,15 +84,17 @@ class Prenet(nn.Module):
 
 
 class LocationLayer(nn.Module):
-    def __init__(self, attention_n_filters, attention_kernel_size,
-                 attention_dim):
+    def __init__(self,
+                 attention_dim,
+                 attention_n_filters=32,
+                 attention_kernel_size=31):
         super(LocationLayer, self).__init__()
         self.location_conv = nn.Conv1d(
             in_channels=2,
             out_channels=attention_n_filters,
-            kernel_size=31,
+            kernel_size=attention_kernel_size,
             stride=1,
-            padding=(31 - 1) // 2,
+            padding=(attention_kernel_size - 1) // 2,
             bias=False)
         self.location_dense = Linear(
             attention_n_filters, attention_dim, bias=False, init_gain='tanh')
@@ -120,8 +122,10 @@ class Attention(nn.Module):
             attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
-                attention_location_n_filters, attention_location_kernel_size,
-                attention_dim)
+                attention_dim,
+                attention_location_n_filters,
+                attention_location_kernel_size,
+            )
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
@@ -203,14 +207,18 @@ class Attention(nn.Module):
         # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device)
                   + self.u * prev_alpha) + 1e-8) * alignment
-        # force incremental alignment - TODO: make configurable
+        # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
-                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
+                alpha[b, :(
+                    n[b] - 1
+                )] = 0  # ignore all previous states to prevent repetition.
+                alpha[b,
+                      (n[b] - 2
+                       )] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
@@ -240,7 +248,8 @@ class Attention(nn.Module):
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
             alignment = torch.sigmoid(attention) / torch.sigmoid(
-                attention).sum(dim=1).unsqueeze(1)
+                attention).sum(
+                    dim=1, keepdim=True)
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.location_attention:
diff --git a/layers/tacotron.py b/layers/tacotron.py
index 8915f385..424f8479 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -268,6 +268,7 @@ class Decoder(nn.Module):
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
         memory_size (int): size of the past window. if <= 0 memory_size = r
+        TODO: arguments
     """
 
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,
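
Reviewer note (not part of the patch): the LocationLayer constructor's positional order changes here from (attention_n_filters, attention_kernel_size, attention_dim) to (attention_dim, attention_n_filters=32, attention_kernel_size=31), so any caller outside this diff that still uses the old positional order would silently build a layer with swapped sizes. A minimal construction sketch, assuming layers/common_layers.py from this repository is importable and using attention_dim=128 as a purely illustrative value:

    from layers.common_layers import LocationLayer

    # New signature introduced by this patch: attention_dim comes first,
    # the conv hyper-parameters fall back to keyword defaults (32 filters,
    # kernel size 31), matching the values previously hard-coded in the layer.
    location_layer = LocationLayer(
        attention_dim=128,           # illustrative value, not from the patch
        attention_n_filters=32,      # default added in this patch
        attention_kernel_size=31,    # default added in this patch
    )

Passing the arguments by keyword, as above, keeps external callers robust against this kind of reordering.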