diff --git a/layers/attention.py b/layers/attention.py
index 5436f110..534e4ba4 100644
--- a/layers/attention.py
+++ b/layers/attention.py
@@ -56,6 +56,7 @@ class LocationSensitiveAttention(nn.Module):
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
         self.v = nn.Linear(attn_dim, 1, bias=False)
+        self.processed_annots = None
         # self.init_layers()
 
     def init_layers(self):
@@ -72,6 +73,9 @@
             self.v.weight,
             gain=torch.nn.init.calculate_gain('linear'))
 
+    def reset(self):
+        self.processed_annots = None
+
     def forward(self, annot, query, loc):
         """
         Shapes:
@@ -86,9 +90,11 @@
         loc_conv = loc_conv.transpose(1, 2)
         processed_loc = self.loc_linear(loc_conv)
         processed_query = self.query_layer(query)
-        processed_annots = self.annot_layer(annot)
+        # cache the annotation projection (constant across decoder steps)
+        if self.processed_annots is None:
+            self.processed_annots = self.annot_layer(annot)
         alignment = self.v(
-            torch.tanh(processed_query + processed_annots + processed_loc))
+            torch.tanh(processed_query + self.processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)
@@ -120,7 +126,7 @@ class AttentionRNNCell(nn.Module):
                 'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
                     align_model))
 
-    def forward(self, memory, context, rnn_state, annots, atten, mask):
+    def forward(self, memory, context, rnn_state, annots, atten, mask, t):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
@@ -130,6 +136,8 @@
             - atten: (batch, 2, max_time)
             - mask: (batch,)
         """
+        if t == 0:
+            self.alignment_model.reset()
         # Concat input query and previous context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
@@ -147,7 +155,8 @@
             alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        # alignment = 5 * alignment
+        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
diff --git a/layers/tacotron.py b/layers/tacotron.py
index f2077106..3e82486f 100644
--- a/layers/tacotron.py
+++ b/layers/tacotron.py
@@ -402,7 +402,7 @@ class Decoder(nn.Module):
                 (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
             attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
                 processed_memory, current_context_vec, attention_rnn_hidden,
-                inputs, attention_cat, mask)
+                inputs, attention_cat, mask, t)
             attention_cum += attention
             # Concat RNN output and attention context vector
             decoder_input = self.project_to_decoder_in(
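
For context, the attention.py hunks above implement a simple memoization: the encoder annotations do not change between decoder steps, so their linear projection is computed once and reused until `reset()` clears it for the next utterance (the `t == 0` check in `AttentionRNNCell.forward`). Below is a minimal standalone sketch of the same pattern; the class name, dimensions, and driver loop are illustrative, not part of the patch:

```python
import torch
import torch.nn as nn


class CachedAnnotProjection(nn.Module):
    """Toy module (hypothetical, not from the patch) mirroring the caching
    pattern: project the encoder annotations once per utterance and reuse
    the result at every decoder step."""

    def __init__(self, annot_dim=256, attn_dim=128):
        super().__init__()
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
        self.processed_annots = None

    def reset(self):
        # Cleared at t == 0 so a stale projection from the previous
        # utterance (possibly with a different max_time) is never reused.
        self.processed_annots = None

    def forward(self, annot):
        # annot: (batch, max_time, annot_dim); project only on a cache miss
        if self.processed_annots is None:
            self.processed_annots = self.annot_layer(annot)
        return self.processed_annots


proj = CachedAnnotProjection()
annot = torch.randn(4, 37, 256)  # (batch, max_time, annot_dim), toy shapes
for t in range(5):
    if t == 0:
        proj.reset()  # mirrors the t == 0 branch in AttentionRNNCell.forward
    context_proj = proj(annot)  # the Linear runs only once, at t == 0
```

The reset at `t == 0` is what makes the cache safe: the cached tensor lives on the module, so without it a second utterance in the same session would attend over the previous utterance's projections (and likely fail on a shape mismatch).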
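
The last attention.py hunk keeps the sigmoid-based normalization (softmax stays commented out): each score is squashed independently, then the row is rescaled to sum to one over `max_time`. A small sketch under assumed shapes; `normalize_alignment` is a name introduced here for illustration:

```python
import torch


def normalize_alignment(alignment: torch.Tensor) -> torch.Tensor:
    """Sigmoid normalization as in the hunk above. Unlike softmax, each
    score saturates independently in (0, 1) before the rescale, so a
    single large score cannot push the remaining weights toward zero."""
    sig = torch.sigmoid(alignment)             # (batch, max_time)
    return sig / sig.sum(dim=1, keepdim=True)  # each row sums to 1


alignment = torch.randn(4, 37)  # (batch, max_time), toy values
weights = normalize_alignment(alignment)
assert torch.allclose(weights.sum(dim=1), torch.ones(4))
```

Note that `sum(dim=1, keepdim=True)` is equivalent to the patch's `.sum(dim=1).unsqueeze(1)`, and computing the sigmoid once into a temporary avoids the duplicated `torch.sigmoid(alignment)` call in the one-liner without changing the result.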