mirror of https://github.com/coqui-ai/TTS.git
small refactoring
parent fd081c49b7
commit 0f0ec679ec
@@ -84,15 +84,17 @@ class Prenet(nn.Module):
 
 
 class LocationLayer(nn.Module):
-    def __init__(self, attention_n_filters, attention_kernel_size,
-                 attention_dim):
+    def __init__(self,
+                 attention_dim,
+                 attention_n_filters=32,
+                 attention_kernel_size=31):
         super(LocationLayer, self).__init__()
         self.location_conv = nn.Conv1d(
             in_channels=2,
             out_channels=attention_n_filters,
-            kernel_size=31,
+            kernel_size=attention_kernel_size,
             stride=1,
-            padding=(31 - 1) // 2,
+            padding=(attention_kernel_size - 1) // 2,
             bias=False)
         self.location_dense = Linear(
             attention_n_filters, attention_dim, bias=False, init_gain='tanh')
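For context, a minimal sketch of the refactored layer in use. The forward pass is not part of this hunk, and torch.nn.Linear stands in for the repo's custom Linear (which also takes an init_gain argument), so treat this as an illustration of the new argument order and defaults, not the project's exact implementation.

import torch
from torch import nn

class LocationLayerSketch(nn.Module):
    """Sketch of the refactored signature: attention_dim first,
    filter count and kernel size moved to keyword defaults."""

    def __init__(self,
                 attention_dim,
                 attention_n_filters=32,
                 attention_kernel_size=31):
        super().__init__()
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=(attention_kernel_size - 1) // 2,
            bias=False)
        # plain nn.Linear used here in place of the repo's custom Linear
        self.location_dense = nn.Linear(
            attention_n_filters, attention_dim, bias=False)

    def forward(self, attention_cat):
        # attention_cat: [B, 2, T] -- previous and cumulative attention weights
        processed = self.location_conv(attention_cat)          # [B, n_filters, T]
        return self.location_dense(processed.transpose(1, 2))  # [B, T, attention_dim]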
@@ -120,8 +122,10 @@ class Attention(nn.Module):
             attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
-                attention_location_n_filters, attention_location_kernel_size,
-                attention_dim)
+                attention_dim,
+                attention_location_n_filters,
+                attention_location_kernel_size,
+            )
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
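A short usage note on the call-site change above, continuing the hypothetical sketch from the previous hunk: attention_dim now comes first positionally, and the new defaults let a caller omit the filter count and kernel size entirely.

# continuing the LocationLayerSketch defined above (hypothetical names)
layer_a = LocationLayerSketch(128, 32, 31)         # new positional order
layer_b = LocationLayerSketch(attention_dim=128)   # defaults: 32 filters, kernel size 31

x = torch.randn(4, 2, 57)     # [batch, 2 channels, decoder steps so far]
print(layer_a(x).shape)       # torch.Size([4, 57, 128])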
@@ -203,14 +207,18 @@ class Attention(nn.Module):
         # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
-        # force incremental alignment - TODO: make configurable
+        # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
-                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
+                alpha[b, :(
+                    n[b] - 1
+                )] = 0  # ignore all previous states to prevent repetition.
+                alpha[b,
+                      (n[b] - 2
+                       )] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
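A small numeric sketch of what the inference-time masking above does, with made-up tensors for a single batch item: positions more than three steps ahead of the previous attention peak n are zeroed, positions before n - 1 are zeroed, index n - 2 is then given 1% of the current peak value as smoothing, and the weights are renormalized. The names prev_alpha, alpha, n, and val mirror the hunk; the values are invented for illustration.

import torch

# toy attention state over 8 encoder steps (values are made up)
prev_alpha = torch.tensor([[0.0, 0.1, 0.7, 0.2, 0.0, 0.0, 0.0, 0.0]])
alpha      = torch.tensor([[0.0, 0.1, 0.3, 0.4, 0.1, 0.05, 0.05, 0.0]])

_, n = prev_alpha.max(1)    # index of the previous peak (2 here)
val, n2 = alpha.max(1)      # value of the current peak

for b in range(alpha.shape[0]):
    alpha[b, n[b] + 3:] = 0               # cut positions far ahead of the peak
    alpha[b, :(n[b] - 1)] = 0             # ignore earlier states to prevent repetition
    alpha[b, (n[b] - 2)] = 0.01 * val[b]  # small smoothing mass behind the peak

alpha = alpha / alpha.sum(dim=1).unsqueeze(1)  # renormalize as in the hunk
print(alpha)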
@@ -240,7 +248,8 @@ class Attention(nn.Module):
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
             alignment = torch.sigmoid(attention) / torch.sigmoid(
-                attention).sum(dim=1).unsqueeze(1)
+                attention).sum(
+                    dim=1, keepdim=True)
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.location_attention:
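The sigmoid branch above replaces .sum(dim=1).unsqueeze(1) with .sum(dim=1, keepdim=True); a quick check on made-up scores that the two normalizers are identical in shape and value.

import torch

attention = torch.randn(4, 10)     # made-up [batch, encoder steps] scores
scores = torch.sigmoid(attention)

old_norm = scores.sum(dim=1).unsqueeze(1)    # [4, 1]
new_norm = scores.sum(dim=1, keepdim=True)   # [4, 1]
assert torch.allclose(old_norm, new_norm)

alignment = scores / new_norm                # rows now sum to 1
print(alignment.sum(dim=1))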
@@ -268,6 +268,7 @@ class Decoder(nn.Module):
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
         memory_size (int): size of the past window. if <= 0 memory_size = r
+        TODO: arguments
     """
 
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,