mirror of https://github.com/coqui-ai/TTS.git
loca sens attn fix
parent
4e6596a8e1
commit
4ef3ecf37f
|
@ -39,7 +39,7 @@ class LocationSensitiveAttention(nn.Module):
|
|||
self.kernel_size = kernel_size
|
||||
self.filters = filters
|
||||
padding = int((kernel_size - 1) / 2)
|
||||
self.loc_conv = nn.Conv1d(1, filters,
|
||||
self.loc_conv = nn.Conv1d(2, filters,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=padding, bias=False)
|
||||
self.loc_linear = nn.Linear(filters, attn_dim)
|
||||
|
@ -77,15 +77,15 @@ class AttentionRNNCell(nn.Module):
|
|||
out_dim (int): context vector feature dimension.
|
||||
rnn_dim (int): rnn hidden state dimension.
|
||||
annot_dim (int): annotation vector feature dimension.
|
||||
memory_dim (int): memory vector (decoder autogression) feature dimension.
|
||||
memory_dim (int): memory vector (decoder output) feature dimension.
|
||||
align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
|
||||
"""
|
||||
super(AttentionRNNCell, self).__init__()
|
||||
self.align_model = align_model
|
||||
self.rnn_cell = nn.GRUCell(out_dim + memory_dim, rnn_dim)
|
||||
self.rnn_cell = nn.GRUCell(annot_dim + memory_dim, rnn_dim)
|
||||
# pick bahdanau or location sensitive attention
|
||||
if align_model == 'b':
|
||||
self.alignment_model = BahdanauAttention(annot_dim, out_dim, out_dim)
|
||||
self.alignment_model = BahdanauAttention(annot_dim, rnn_dim, out_dim)
|
||||
if align_model == 'ls':
|
||||
self.alignment_model = LocationSensitiveAttention(annot_dim, rnn_dim, out_dim)
|
||||
else:
|
||||
|
|
|
@ -198,6 +198,7 @@ class Decoder(nn.Module):
|
|||
def __init__(self, in_features, memory_dim, r):
|
||||
super(Decoder, self).__init__()
|
||||
self.r = r
|
||||
self.in_features = in_features
|
||||
self.max_decoder_steps = 200
|
||||
self.memory_dim = memory_dim
|
||||
# memory -> |Prenet| -> processed_memory
|
||||
|
@ -249,7 +250,7 @@ class Decoder(nn.Module):
|
|||
attention_rnn_hidden = inputs.data.new(B, 256).zero_()
|
||||
decoder_rnn_hiddens = [inputs.data.new(B, 256).zero_()
|
||||
for _ in range(len(self.decoder_rnns))]
|
||||
current_context_vec = inputs.data.new(B, 128).zero_()
|
||||
current_context_vec = inputs.data.new(B, self.in_features).zero_()
|
||||
stopnet_rnn_hidden = inputs.data.new(B, self.r * self.memory_dim).zero_()
|
||||
# attention states
|
||||
attention = inputs.data.new(B, T).zero_()
|
||||
|
|
Loading…
Reference in New Issue