Mirror of https://github.com/coqui-ai/TTS.git

Enable optional forward attention with transition agent

parent: e2cf35bb10
commit: 312a539a0e
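For reference, the update implemented by the new apply_forward_attention method in the hunks below corresponds to the forward-attention recursion. Symbols mirror the code: alpha is the (unnormalized) forward variable carried between decoder steps, y_t the content-based alignment, u the transition probability, and 1e-7 a small stabilizer; normalization is applied only when computing the context:

\[
\hat{\alpha}_t(n) = \left( (1 - u_{t-1})\,\hat{\alpha}_{t-1}(n) + u_{t-1}\,\hat{\alpha}_{t-1}(n-1) + 10^{-7} \right) y_t(n),
\qquad
\alpha_t(n) = \frac{\hat{\alpha}_t(n)}{\sum_m \hat{\alpha}_t(m)}
\]

When the transition agent is enabled, \(u_t = \sigma(W_{ta}\,[c_t; q_t] + b_{ta})\), with \(c_t\) the attention context and \(q_t\) the processed query; otherwise u keeps its initial value.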
@@ -41,7 +41,9 @@
     "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "bn", // ONLY TACOTRON2 - "original" or "bn".
-    "use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "loss_masking": false, // enable / disable loss masking against the sequence padding.

     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,

@@ -42,6 +42,8 @@
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
+    "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
     "loss_masking": false, // enable / disable loss masking against the sequence padding.

     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,

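As an illustrative sketch only (not part of this commit), the two new config keys would be read from the loaded config and mapped onto the constructor arguments added in the model hunks further down; the file name and glue code here are assumptions.

import json

# Hypothetical glue (not from this commit): read the new flags from a config
# file and map them onto the constructor arguments added below.
with open("config.json") as f:
    c = json.load(f)

model_kwargs = dict(
    attn_norm=c["attention_norm"],
    prenet_type=c["prenet_type"],
    forward_attn=c["use_forward_attn"],
    trans_agent=c["transition_agent"],
)
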
@@ -122,13 +122,15 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm, forward_attn):
+                 windowing, norm, forward_attn, trans_agent):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
+        if trans_agent:
+            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
         self.location_layer = LocationLayer(attention_location_n_filters,
                                             attention_location_kernel_size,
                                             attention_dim)

@@ -137,6 +139,7 @@ class Attention(nn.Module):
         self.win_idx = None
         self.norm = norm
         self.forward_attn = forward_attn
+        self.trans_agent = trans_agent

     def init_win_idx(self):
         self.win_idx = -1

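The hunks below update self.alpha and self.u, whose initialization is not shown in this diff. A minimal sketch of a state reset consistent with the update rule, added here only for illustration (an assumption, not code from the commit):

    def init_forward_attn_state(self, inputs):
        # Hypothetical helper (not in this diff): put all attention mass on
        # the first encoder step and start from a neutral transition
        # probability before decoding begins.
        B, T = inputs.size(0), inputs.size(1)
        self.alpha = torch.cat(
            [torch.ones(B, 1), torch.zeros(B, T - 1) + 1e-7],
            dim=1).to(inputs.device)
        self.u = (0.5 * torch.ones(B, 1)).to(inputs.device)
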
@@ -160,29 +163,46 @@ class Attention(nn.Module):
             processed_inputs))

         energies = energies.squeeze(-1)
-        return energies
+        return energies, processed_query

+    def apply_windowing(self, attention):
+        back_win = self.win_idx - self.win_back
+        front_win = self.win_idx + self.win_front
+        if back_win > 0:
+            attention[:, :back_win] = -float("inf")
+        if front_win < inputs.shape[1]:
+            attention[:, front_win:] = -float("inf")
+        # this is a trick to solve a special problem.
+        # but it does not hurt.
+        if self.win_idx == -1:
+            attention[:, 0] = attention.max()
+        # Update the window
+        self.win_idx = torch.argmax(attention, 1).long()[0].item()
+        return attention
+
+    def apply_forward_attention(self, inputs, alignment, processed_query):
+        # forward attention
+        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+        self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+        alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+        # compute context
+        context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        # compute transition agent
+        if self.trans_agent:
+            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
+        return context, alpha_norm, alignment
+
     def forward(self, attention_hidden_state, inputs, processed_inputs,
                 attention_cat, mask):
-        attention = self.get_attention(
+        attention, processed_query = self.get_attention(
             attention_hidden_state, processed_inputs, attention_cat)

         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
         # Windowing
         if not self.training and self.windowing:
-            back_win = self.win_idx - self.win_back
-            front_win = self.win_idx + self.win_front
-            if back_win > 0:
-                attention[:, :back_win] = -float("inf")
-            if front_win < inputs.shape[1]:
-                attention[:, front_win:] = -float("inf")
-            # this is a trick to solve a special problem.
-            # but it does not hurt.
-            if self.win_idx == -1:
-                attention[:, 0] = attention.max()
-            # Update the window
-            self.win_idx = torch.argmax(attention, 1).long()[0].item()
+            attention = self.apply_windowing(attention)
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":

@@ -191,14 +211,7 @@ class Attention(nn.Module):
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.forward_attn:
-            # forward attention
-            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
-            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
-            # compute context
-            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            return context, alpha_norm, alignment
+            return self.apply_forward_attention(inputs, alignment, processed_query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)

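One detail worth noting about the new apply_windowing above: its body refers to inputs, which is not among its parameters (the check was lifted from forward, where inputs is in scope), so it would raise a NameError if windowing were active at inference. A hedged corrected sketch, not part of this commit, passes the encoder outputs explicitly:

    def apply_windowing(self, attention, inputs):
        # Same logic as the version added above, but `inputs` is an explicit
        # argument so the forward window can be clipped to the encoder length.
        back_win = self.win_idx - self.win_back
        front_win = self.win_idx + self.win_front
        if back_win > 0:
            attention[:, :back_win] = -float("inf")
        if front_win < inputs.shape[1]:
            attention[:, front_win:] = -float("inf")
        if self.win_idx == -1:
            attention[:, 0] = attention.max()
        self.win_idx = torch.argmax(attention, 1).long()[0].item()
        return attention

with the call site in forward becoming attention = self.apply_windowing(attention, inputs).
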
@@ -272,7 +285,7 @@ class Encoder(nn.Module):

 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r

@@ -292,7 +305,7 @@ class Decoder(nn.Module):
                                          self.attention_rnn_dim)

         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm, forward_attn)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)

         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask

 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r

@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

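As a closing usage sketch (argument values are illustrative, not taken from this commit): forward_attn switches the context computation to the alpha recursion, while trans_agent only controls whether self.ta is created and whether u is re-estimated each step, so the two flags can be toggled independently.

# Illustrative only: enable forward attention with and without the
# transition agent on the constructor shown above.
model_ta = Tacotron2(num_chars=61, r=5,
                     attn_norm="softmax", prenet_type="original",
                     forward_attn=True, trans_agent=True)

model_no_ta = Tacotron2(num_chars=61, r=5,
                        forward_attn=True, trans_agent=False)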