mirror of https://github.com/coqui-ai/TTS.git

Cache attention annot vectors for the whole sequence.

parent 211a20a47a
commit dc3d09304e
@@ -56,6 +56,7 @@ class LocationSensitiveAttention(nn.Module):
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
         self.v = nn.Linear(attn_dim, 1, bias=False)
+        self.processed_annots = None
         # self.init_layers()

     def init_layers(self):
@@ -72,6 +73,9 @@ class LocationSensitiveAttention(nn.Module):
             self.v.weight,
             gain=torch.nn.init.calculate_gain('linear'))

+    def reset(self):
+        self.processed_annots = None
+
     def forward(self, annot, query, loc):
         """
         Shapes:
@@ -86,9 +90,11 @@ class LocationSensitiveAttention(nn.Module):
         loc_conv = loc_conv.transpose(1, 2)
         processed_loc = self.loc_linear(loc_conv)
         processed_query = self.query_layer(query)
-        processed_annots = self.annot_layer(annot)
+        # cache annots
+        if self.processed_annots is None:
+            self.processed_annots = self.annot_layer(annot)
         alignment = self.v(
-            torch.tanh(processed_query + processed_annots + processed_loc))
+            torch.tanh(processed_query + self.processed_annots + processed_loc))
         # (batch, max_time)
         return alignment.squeeze(-1)

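The pattern this hunk introduces, as a self-contained sketch (class and dimension names here are illustrative, not the repository's API): the annotation projection depends only on the encoder outputs, so it can be computed at the first decoder step and reused for every later step of the same sequence.

import torch
import torch.nn as nn

class CachedAnnotProjection(nn.Module):
    """Illustrative sketch: project encoder annotations once per sequence."""

    def __init__(self, annot_dim=256, attn_dim=128):
        super().__init__()
        self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
        self.processed_annots = None

    def reset(self):
        # Invalidate the cache before decoding a new sequence.
        self.processed_annots = None

    def forward(self, annot):
        # annot: (batch, max_time, annot_dim). The linear projection runs
        # only on the first call after reset(); later calls reuse the result.
        if self.processed_annots is None:
            self.processed_annots = self.annot_layer(annot)
        return self.processed_annots

Over a decode of T frames this turns T projections of the full encoder sequence into one, which is what the commit message promises; the price is that the cache must be invalidated explicitly, which is what reset() above and the new t argument below are for.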
@@ -120,7 +126,7 @@ class AttentionRNNCell(nn.Module):
             'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
                 align_model))

-    def forward(self, memory, context, rnn_state, annots, atten, mask):
+    def forward(self, memory, context, rnn_state, annots, atten, mask, t):
         """
         Shapes:
             - memory: (batch, 1, dim) or (batch, dim)
@@ -130,6 +136,8 @@ class AttentionRNNCell(nn.Module):
             - atten: (batch, 2, max_time)
             - mask: (batch,)
         """
+        if t == 0:
+            self.alignment_model.reset()
         # Concat input query and previous context
         rnn_input = torch.cat((memory, context), -1)
         # Feed it to RNN
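How t is meant to drive that invalidation, as a toy loop (the names and shapes are assumptions for illustration; CachedAnnotProjection is the sketch above):

import torch

proj = CachedAnnotProjection(annot_dim=256, attn_dim=128)
for utterance in range(2):
    annot = torch.randn(4, 50, 256)   # new encoder outputs per utterance
    for t in range(10):
        if t == 0:
            proj.reset()              # mirrors self.alignment_model.reset()
        processed = proj(annot)       # projected at t == 0, reused for t = 1..9

Without the reset at t == 0, the second utterance would silently attend over the first utterance's cached projection.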
@@ -147,7 +155,8 @@ class AttentionRNNCell(nn.Module):
             alignment.masked_fill_(1 - mask, -float("inf"))
         # Normalize context weight
         # alignment = F.softmax(alignment, dim=-1)
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        # alignment = 5 * alignment
+        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
         # Attention context vector
         # (batch, 1, dim)
         # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
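The normalization kept here is worth spelling out: instead of the commented-out softmax, each score is squashed independently with a sigmoid and the results are renormalized to sum to 1 over the time axis, which yields a flatter, less peaked alignment than softmax's exponential weighting. A small numeric check of the difference (values are illustrative):

import torch
import torch.nn.functional as F

e = torch.tensor([[2.0, 0.0, -2.0]])          # (batch=1, max_time=3) scores

softmax_w = F.softmax(e, dim=-1)              # peaked: ~[0.87, 0.12, 0.02]
sig = torch.sigmoid(e)
smooth_w = sig / sig.sum(dim=1).unsqueeze(1)  # flatter: ~[0.59, 0.33, 0.08]

assert torch.allclose(smooth_w.sum(dim=1), torch.ones(1))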
@@ -402,7 +402,7 @@ class Decoder(nn.Module):
             (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
         attention_rnn_hidden, current_context_vec, attention = self.attention_rnn(
             processed_memory, current_context_vec, attention_rnn_hidden,
-            inputs, attention_cat, mask)
+            inputs, attention_cat, mask, t)
         attention_cum += attention
         # Concat RNN output and attention context vector
         decoder_input = self.project_to_decoder_in(
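Taken together, the call-site change threads the decoder's step counter through to the attention cell. A condensed reconstruction of the surrounding loop, as a sketch under assumed names (the real Decoder.forward carries more state than shown):

import torch

def decode(attention_rnn, processed_memory, inputs, mask, attention_rnn_hidden,
           current_context_vec, attention, attention_cum, max_decoder_steps):
    """Illustrative loop around the changed call site."""
    for t in range(max_decoder_steps):
        attention_cat = torch.cat(
            (attention.unsqueeze(1), attention_cum.unsqueeze(1)), dim=1)
        # t == 0 makes the cell call alignment_model.reset(), so the annot
        # cache is rebuilt once per sequence rather than once per step.
        attention_rnn_hidden, current_context_vec, attention = attention_rnn(
            processed_memory, current_context_vec, attention_rnn_hidden,
            inputs, attention_cat, mask, t)
        attention_cum = attention_cum + attention
    return attention_rnn_hidden, current_context_vec, attention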