mirror of https://github.com/coqui-ai/TTS.git
Explicit padding for unbalanced padding sizes
parent fd830c6416
commit 67df385275
@@ -43,13 +43,15 @@ class LocationSensitiveAttention(nn.Module):
         self.kernel_size = kernel_size
         self.filters = filters
         padding = [(kernel_size - 1) // 2, (kernel_size - 1) // 2]
-        self.loc_conv = nn.Conv1d(
-            2,
-            filters,
-            kernel_size=kernel_size,
-            stride=1,
-            padding=padding,
-            bias=False)
+        self.loc_conv = nn.Sequential(
+            nn.ConstantPad1d(padding, 0),
+            nn.Conv1d(
+                2,
+                filters,
+                kernel_size=kernel_size,
+                stride=1,
+                padding=0,
+                bias=False))
         self.loc_linear = nn.Linear(filters, attn_dim)
         self.query_layer = nn.Linear(query_dim, attn_dim, bias=True)
         self.annot_layer = nn.Linear(annot_dim, attn_dim, bias=True)
@@ -100,8 +102,8 @@ class AttentionRNNCell(nn.Module):
                 annot_dim, rnn_dim, out_dim)
         else:
             raise RuntimeError(" Wrong alignment model name: {}. Use\
-                'b' (Bahdanau) or 'ls' (Location Sensitive)."
-                               .format(align_model))
+                'b' (Bahdanau) or 'ls' (Location Sensitive).".format(
+                    align_model))

     def forward(self, memory, context, rnn_state, annots, atten, mask):
         """
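For context, a minimal sketch (not part of the commit; the kernel size, filter count, and tensor shapes below are illustrative) of the pattern the patch switches to: nn.Conv1d's padding argument pads both ends of the sequence by the same amount, whereas nn.ConstantPad1d takes a separate (left, right) pair, so wrapping the convolution in nn.Sequential lets unbalanced padding sizes be stated explicitly.

import torch
from torch import nn

# Illustrative values; the patched loc_conv takes 2 input channels and a
# configurable number of filters.
kernel_size = 31
filters = 32
padding = [(kernel_size - 1) // 2, (kernel_size - 1) // 2]  # (left, right) may differ

# Pad explicitly, then convolve with padding=0, as in the patched loc_conv.
loc_conv = nn.Sequential(
    nn.ConstantPad1d(padding, 0),
    nn.Conv1d(2, filters, kernel_size=kernel_size, stride=1, padding=0, bias=False))

x = torch.randn(4, 2, 100)   # (batch, channels, time)
print(loc_conv(x).shape)     # torch.Size([4, 32, 100]); time length is kept
                             # because left + right == kernel_size - 1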