mirror of https://github.com/coqui-ai/TTS.git
small refactoring
parent fd081c49b7
commit 0f0ec679ec
@@ -84,15 +84,17 @@ class Prenet(nn.Module):
 
 
 class LocationLayer(nn.Module):
-    def __init__(self, attention_n_filters, attention_kernel_size,
-                 attention_dim):
+    def __init__(self,
+                 attention_dim,
+                 attention_n_filters=32,
+                 attention_kernel_size=31):
         super(LocationLayer, self).__init__()
         self.location_conv = nn.Conv1d(
             in_channels=2,
             out_channels=attention_n_filters,
-            kernel_size=31,
+            kernel_size=attention_kernel_size,
             stride=1,
-            padding=(31 - 1) // 2,
+            padding=(attention_kernel_size - 1) // 2,
             bias=False)
         self.location_dense = Linear(
             attention_n_filters, attention_dim, bias=False, init_gain='tanh')
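
The refactor moves attention_dim to the front of LocationLayer's signature and turns the filter count and kernel size into keyword arguments with defaults (32 and 31), so the previously hardcoded kernel size and padding now follow whatever the constructor receives. Below is a minimal, self-contained sketch of the refactored layer for reference; the forward pass and the plain nn.Linear standing in for the repo's Linear wrapper (which also takes init_gain) are assumptions added for illustration, not part of this diff.

# Minimal sketch of the refactored LocationLayer (assumptions noted above).
import torch
import torch.nn as nn


class LocationLayerSketch(nn.Module):
    def __init__(self,
                 attention_dim,
                 attention_n_filters=32,
                 attention_kernel_size=31):
        super().__init__()
        # 2 input channels: attention weights and their cumulative sum.
        self.location_conv = nn.Conv1d(
            in_channels=2,
            out_channels=attention_n_filters,
            kernel_size=attention_kernel_size,
            stride=1,
            padding=(attention_kernel_size - 1) // 2,
            bias=False)
        # stand-in for the repo's Linear wrapper with init_gain='tanh'
        self.location_dense = nn.Linear(
            attention_n_filters, attention_dim, bias=False)

    def forward(self, attention_cat):
        # attention_cat: [B, 2, T_in] -> [B, T_in, attention_dim]
        processed = self.location_conv(attention_cat)
        return self.location_dense(processed.transpose(1, 2))


# (kernel_size - 1) // 2 padding keeps the time axis unchanged for odd
# kernel sizes, so the default of 31 preserves T_in.
attn_cat = torch.rand(4, 2, 50)                        # [B, 2, T_in]
out = LocationLayerSketch(attention_dim=128)(attn_cat)
assert out.shape == (4, 50, 128)
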
@@ -120,8 +122,10 @@ class Attention(nn.Module):
                 attention_rnn_dim + embedding_dim, 1, bias=True)
         if location_attention:
             self.location_layer = LocationLayer(
-                attention_location_n_filters, attention_location_kernel_size,
-                attention_dim)
+                attention_dim,
+                attention_location_n_filters,
+                attention_location_kernel_size,
+            )
         self._mask_value = -float("inf")
         self.windowing = windowing
         self.win_idx = None
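
The call site has to change together with the signature: with the old positional order, the new constructor would silently read the filter count as attention_dim. A small, hypothetical sketch of that hazard (location_layer_sizes is not a repo function, just an illustration of the argument reordering):

# Hypothetical stand-in with the new argument order and defaults.
def location_layer_sizes(attention_dim,
                         attention_n_filters=32,
                         attention_kernel_size=31):
    return attention_dim, attention_n_filters, attention_kernel_size

# New-style call: keyword arguments cannot be mis-ordered.
assert location_layer_sizes(attention_dim=128,
                            attention_n_filters=32,
                            attention_kernel_size=31) == (128, 32, 31)

# Old positional order against the new signature: 32 would be taken as
# attention_dim and 128 as the filter count, with no error raised.
assert location_layer_sizes(32, 31, 128) == (32, 31, 128)
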
@@ -203,14 +207,18 @@ class Attention(nn.Module):
         # compute transition potentials
         alpha = (((1 - self.u) * self.alpha.clone().to(inputs.device) +
                   self.u * prev_alpha) + 1e-8) * alignment
-        # force incremental alignment - TODO: make configurable
+        # force incremental alignment
         if not self.training and self.forward_attn_mask:
             _, n = prev_alpha.max(1)
             val, n2 = alpha.max(1)
             for b in range(alignment.shape[0]):
                 alpha[b, n[b] + 3:] = 0
-                alpha[b, :(n[b] - 1)] = 0  # ignore all previous states to prevent repetition.
-                alpha[b, (n[b] - 2)] = 0.01 * val[b]  # smoothing factor for the prev step
+                alpha[b, :(
+                    n[b] - 1
+                )] = 0  # ignore all previous states to prevent repetition.
+                alpha[b,
+                      (n[b] - 2
+                       )] = 0.01 * val[b]  # smoothing factor for the prev step
         # compute attention weights
         self.alpha = alpha / alpha.sum(dim=1).unsqueeze(1)
         # compute context
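
Apart from the dropped "TODO: make configurable" note, this hunk only re-wraps the masking statements. For readers of the forward-attention mask itself, here is a toy, standalone sketch of what the masked block does at inference time; the shapes, tensors, and the final keepdim=True normalization are illustrative assumptions, not code from the repo.

# Toy sketch of the inference-time forward-attention mask: keep only a
# narrow window around the previously attended position n, zero positions
# more than 3 steps ahead and everything before n - 1, and leave a small
# smoothing value at n - 2.
import torch

alpha = torch.rand(2, 10)        # [B, T_in] transition potentials
prev_alpha = torch.rand(2, 10)   # alignment from the previous decoder step

_, n = prev_alpha.max(1)         # previously attended position per batch item
val, _ = alpha.max(1)            # current peak value, used for smoothing

for b in range(alpha.shape[0]):
    alpha[b, n[b] + 3:] = 0                # block jumps far ahead
    alpha[b, :(n[b] - 1)] = 0              # block repetition of earlier inputs
    alpha[b, (n[b] - 2)] = 0.01 * val[b]   # small mass on the step before
    # note: for n[b] < 2 the bound/index go negative and wrap around,
    # mirroring the behaviour of the code in the diff.

alpha = alpha / alpha.sum(dim=1, keepdim=True)   # renormalize to weights
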
@@ -240,7 +248,8 @@ class Attention(nn.Module):
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
             alignment = torch.sigmoid(attention) / torch.sigmoid(
-                attention).sum(dim=1).unsqueeze(1)
+                attention).sum(
+                    dim=1, keepdim=True)
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.location_attention:
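
The sigmoid-normalization change is behavior-preserving: sum(dim=1, keepdim=True) equals sum(dim=1).unsqueeze(1). A quick check with a made-up score tensor:

# Sanity check that the old and new normalizations match.
import torch

attention = torch.randn(4, 50)   # [B, T_in] attention scores (made up)
sig = torch.sigmoid(attention)
old = sig / sig.sum(dim=1).unsqueeze(1)
new = sig / sig.sum(dim=1, keepdim=True)
assert torch.allclose(old, new)
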
@@ -268,6 +268,7 @@ class Decoder(nn.Module):
         memory_dim (int): memory vector (prev. time-step output) sample size.
         r (int): number of outputs per time step.
         memory_size (int): size of the past window. if <= 0 memory_size = r
+        TODO: arguments
     """
 
     def __init__(self, in_features, memory_dim, r, memory_size, attn_windowing,