mirror of https://github.com/coqui-ai/TTS.git
remove comments
parent 1e8fdec084
commit 44c66c6e3e
@@ -146,12 +146,7 @@ class AttentionRNNCell(nn.Module):
        if t == 0:
            self.alignment_model.reset()
            self.win_idx = 0
        # Feed the concatenated memory and context to the RNN cell
        # s_i = f(y_{i-1}, c_{i}, s_{i-1})
        rnn_output = self.rnn_cell(torch.cat((memory, context), -1), rnn_state)
        # Alignment
        # (batch, max_time)
        # e_{ij} = a(s_{i-1}, h_j)
        if self.align_model == 'b':
            alignment = self.alignment_model(annots, rnn_output)
        else:
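
For reference, this hunk runs one decoder step: the previous memory and attention context are concatenated, fed to the RNN cell to produce s_i, and the new state is scored against the encoder annotations. Below is a minimal, self-contained sketch of that pattern, with a plain Bahdanau-style additive scorer standing in for self.alignment_model; the class name, layer sizes, and batch shapes are assumptions for illustration, not the repository's actual modules.

import torch
import torch.nn as nn

# Hypothetical sizes for illustration; the real ones come from the model config.
B, T_max, annot_dim, memory_dim, rnn_dim, attn_dim = 2, 50, 256, 128, 256, 128

class BahdanauScorer(nn.Module):
    # Additive scorer e_{ij} = v^T tanh(W q + U h_j); a stand-in for self.alignment_model.
    def __init__(self, annot_dim, query_dim, attn_dim):
        super().__init__()
        self.annot_layer = nn.Linear(annot_dim, attn_dim)
        self.query_layer = nn.Linear(query_dim, attn_dim)
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, annots, query):
        # annots: (batch, max_time, annot_dim), query: (batch, query_dim)
        scores = self.v(torch.tanh(self.annot_layer(annots) + self.query_layer(query).unsqueeze(1)))
        return scores.squeeze(-1)  # (batch, max_time)

rnn_cell = nn.GRUCell(memory_dim + annot_dim, rnn_dim)
scorer = BahdanauScorer(annot_dim, rnn_dim, attn_dim)

memory = torch.randn(B, memory_dim)        # y_{i-1}
context = torch.randn(B, annot_dim)        # previous context c_{i-1}
rnn_state = torch.randn(B, rnn_dim)        # s_{i-1}
annots = torch.randn(B, T_max, annot_dim)  # encoder outputs h_j

# s_i = f(y_{i-1}, c_{i-1}, s_{i-1}); then score s_i against every h_j
rnn_output = rnn_cell(torch.cat((memory, context), -1), rnn_state)
alignment = scorer(annots, rnn_output)

print(rnn_output.shape, alignment.shape)  # torch.Size([2, 256]) torch.Size([2, 50])

Running the sketch prints one new hidden state per batch item and one score per encoder timestep, which is the (batch, max_time) shape the comment in the hunk refers to.
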
@@ -169,13 +164,7 @@ class AttentionRNNCell(nn.Module):
        alignment[:, front_win:] = -float("inf")
        # Update the window
        self.win_idx = torch.argmax(alignment, 1).long()[0].item()
        # Normalize context weight
        # alignment = F.softmax(alignment, dim=-1)
        # alignment = 5 * alignment
        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
        # Attention context vector
        # (batch, 1, dim)
        # c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j
        context = torch.bmm(alignment.unsqueeze(1), annots)
        context = context.squeeze(1)
        return rnn_output, context, alignment
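
The second hunk masks scores outside the attention window, normalizes them with a sigmoid instead of the commented-out softmax, and forms the context vector as a weighted sum of the annotations. A small sketch of just those steps on random tensors, with hypothetical sizes and window position:

import torch

B, T_max, annot_dim = 2, 50, 256
alignment = torch.randn(B, T_max)          # raw scores e_{ij}
annots = torch.randn(B, T_max, annot_dim)  # encoder annotations h_j

# Mask scores beyond the attention window, as in the hunk above.
front_win = 30  # hypothetical window edge
alignment[:, front_win:] = -float("inf")

# Sigmoid normalization: sigmoid(-inf) == 0, so masked positions vanish,
# and dividing by the row sum makes the weights sum to one per decoder step.
alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)

# c_i = \sum_j \alpha_{ij} h_j via a batched matrix multiply:
# (batch, 1, max_time) x (batch, max_time, annot_dim) -> (batch, 1, annot_dim)
context = torch.bmm(alignment.unsqueeze(1), annots).squeeze(1)

print(alignment.sum(dim=1))  # tensor([1., 1.]) up to rounding
print(context.shape)         # torch.Size([2, 256])

Because sigmoid(-inf) is 0, the masked positions contribute nothing, and the row-sum division keeps the weights summing to one per decoder step just as a softmax would; avoiding the exponential competition between timesteps is presumably why the softmax lines were left commented out.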