mirror of https://github.com/coqui-ai/TTS.git
Enable optional forward attention with transition agent
parent e2cf35bb10
commit 312a539a0e
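For orientation: this commit turns on the forward-attention mechanism (usually attributed to Zhang et al., 2018) with an optional transition agent. In the notation of the code added below (notation mine), the alignment y_t produced by the usual content/location attention is reweighted so that attention mass can only stay in place or advance one encoder step per decoder step:

\[
\hat{\alpha}_t(n) = \bigl((1 - u_{t-1})\,\alpha_{t-1}(n) + u_{t-1}\,\alpha_{t-1}(n-1) + 10^{-7}\bigr)\, y_t(n),
\qquad
\alpha_t(n) = \frac{\hat{\alpha}_t(n)}{\sum_m \hat{\alpha}_t(m)},
\]

and, when the transition agent is enabled,

\[
u_t = \sigma\bigl(W_{ta}\,[\,c_t ;\, q_t\,] + b_{ta}\bigr),
\]

where c_t is the attention context, q_t the processed query, and u_t the probability of moving forward. With the agent disabled, u is presumably held at its initial value; that initialization is not part of this diff.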
@@ -41,7 +41,9 @@
     "memory_size": 5, // ONLY TACOTRON - memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "bn", // ONLY TACOTRON2 - "original" or "bn".
-    "use_forward_attn": false, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+    "loss_masking": false, // enable / disable loss masking against the sequence padding.
 
     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
@@ -42,6 +42,8 @@
     "attention_norm": "softmax", // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "prenet_type": "original", // ONLY TACOTRON2 - "original" or "bn".
     "use_forward_attn": true, // ONLY TACOTRON2 - if it uses forward attention. In general, it aligns faster.
+    "transition_agent": true, // ONLY TACOTRON2 - enable/disable transition agent of forward attention.
+    "loss_masking": false, // enable / disable loss masking against the sequence padding.
 
     "batch_size": 16, // Batch size for training. Lower values than 32 might cause hard to learn attention.
     "eval_batch_size":16,
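The new config entries map one-to-one onto the keyword arguments added to Tacotron2 further down in this diff. A hypothetical sketch of the wiring (the plumbing in train.py is not part of this commit; the import path and the placeholder values are assumptions):

from models.tacotron2 import Tacotron2  # assumed module path at this commit

# Values mirroring the config entries above; num_chars and r are placeholders.
model = Tacotron2(
    num_chars=61,
    r=5,
    attn_win=False,
    attn_norm="softmax",     # "attention_norm"
    prenet_type="original",  # "prenet_type"
    forward_attn=True,       # "use_forward_attn"
    trans_agent=True,        # "transition_agent"
)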
@@ -122,13 +122,15 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing, norm, forward_attn):
+                 windowing, norm, forward_attn, trans_agent):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
         self.inputs_layer = Linear(
             embedding_dim, attention_dim, bias=False, init_gain='tanh')
         self.v = Linear(attention_dim, 1, bias=True)
+        if trans_agent:
+            self.ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)
         self.location_layer = LocationLayer(attention_location_n_filters,
                                             attention_location_kernel_size,
                                             attention_dim)
@@ -137,6 +139,7 @@ class Attention(nn.Module):
         self.win_idx = None
         self.norm = norm
         self.forward_attn = forward_attn
+        self.trans_agent = trans_agent
 
     def init_win_idx(self):
         self.win_idx = -1
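A standalone sketch of what the new transition-agent layer computes, using the dimensions this diff passes in from the Decoder (attention_dim=128, encoder/embedding dim=512); the toy tensors are illustrative only:

import torch
import torch.nn as nn

attention_dim, embedding_dim, batch = 128, 512, 2
ta = nn.Linear(attention_dim + embedding_dim, 1, bias=True)

context = torch.randn(batch, embedding_dim)              # attention context vector
processed_query = torch.randn(batch, 1, attention_dim)   # query after the query_layer
ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
u = torch.sigmoid(ta(ta_input))                          # transition probability in (0, 1)
print(u.shape)                                           # torch.Size([2, 1])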
@@ -160,29 +163,46 @@
             processed_inputs))
 
         energies = energies.squeeze(-1)
-        return energies
+        return energies, processed_query
 
+    def apply_windowing(self, attention, inputs):
+        back_win = self.win_idx - self.win_back
+        front_win = self.win_idx + self.win_front
+        if back_win > 0:
+            attention[:, :back_win] = -float("inf")
+        if front_win < inputs.shape[1]:
+            attention[:, front_win:] = -float("inf")
+        # this is a trick to solve a special problem.
+        # but it does not hurt.
+        if self.win_idx == -1:
+            attention[:, 0] = attention.max()
+        # Update the window
+        self.win_idx = torch.argmax(attention, 1).long()[0].item()
+        return attention
+
+    def apply_forward_attention(self, inputs, alignment, processed_query):
+        # forward attention
+        prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
+        self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
+        alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
+        # compute context
+        context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
+        context = context.squeeze(1)
+        # compute transition agent
+        if self.trans_agent:
+            ta_input = torch.cat([context, processed_query.squeeze(1)], dim=-1)
+            self.u = torch.sigmoid(self.ta(ta_input))
+        return context, alpha_norm, alignment
+
     def forward(self, attention_hidden_state, inputs, processed_inputs,
                 attention_cat, mask):
-        attention = self.get_attention(
+        attention, processed_query = self.get_attention(
             attention_hidden_state, processed_inputs, attention_cat)
 
         if mask is not None:
             attention.data.masked_fill_(1 - mask, self._mask_value)
-        # Windowing
         if not self.training and self.windowing:
-            back_win = self.win_idx - self.win_back
-            front_win = self.win_idx + self.win_front
-            if back_win > 0:
-                attention[:, :back_win] = -float("inf")
-            if front_win < inputs.shape[1]:
-                attention[:, front_win:] = -float("inf")
-            # this is a trick to solve a special problem.
-            # but it does not hurt.
-            if self.win_idx == -1:
-                attention[:, 0] = attention.max()
-            # Update the window
-            self.win_idx = torch.argmax(attention, 1).long()[0].item()
+            attention = self.apply_windowing(attention, inputs)
         if self.norm == "softmax":
             alignment = torch.softmax(attention, dim=-1)
         elif self.norm == "sigmoid":
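To see what apply_forward_attention does numerically, here is a toy walk-through of the alpha update with one batch item and four encoder steps (numbers invented; for a 2D tensor, F.pad(..., (1, 0)) is equivalent to the (1, 0, 0, 0) padding used above):

import torch
import torch.nn.functional as F

alpha = torch.tensor([[0.7, 0.2, 0.1, 0.0]])      # alpha_{t-1}
alignment = torch.tensor([[0.1, 0.6, 0.2, 0.1]])  # content/location attention at step t
u = torch.tensor([[0.5]])                         # transition probability from the agent

prev_alpha = F.pad(alpha[:, :-1], (1, 0))         # shift right by one encoder step
alpha = (((1 - u) * alpha + u * prev_alpha) + 1e-7) * alignment
alpha_norm = alpha / alpha.sum(dim=1, keepdim=True)
print(alpha_norm)  # mass only stays in place or moves forward, roughly [0.10, 0.79, 0.09, 0.01]

The 1e-7 term keeps the product from collapsing to exactly zero everywhere, which would make the normalization ill-defined.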
@@ -191,14 +211,7 @@ class Attention(nn.Module):
         else:
             raise RuntimeError("Unknown value for attention norm type")
         if self.forward_attn:
-            # forward attention
-            prev_alpha = F.pad(self.alpha[:, :-1].clone(), (1, 0, 0, 0)).to(inputs.device)
-            self.alpha = (((1-self.u) * self.alpha.clone().to(inputs.device) + self.u * prev_alpha) + 1e-7) * alignment
-            alpha_norm = self.alpha / self.alpha.sum(dim=1).unsqueeze(1)
-            # compute context
-            context = torch.bmm(alpha_norm.unsqueeze(1), inputs)
-            context = context.squeeze(1)
-            return context, alpha_norm, alignment
+            return self.apply_forward_attention(inputs, alignment, processed_query)
         else:
             context = torch.bmm(alignment.unsqueeze(1), inputs)
             context = context.squeeze(1)
@@ -272,7 +285,7 @@ class Encoder(nn.Module):
 
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -292,7 +305,7 @@ class Decoder(nn.Module):
                                        self.attention_rnn_dim)
 
         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win, attn_norm, forward_attn)
+                                         128, 32, 31, attn_win, attn_norm, forward_attn, trans_agent)
 
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax", prenet_type="original", forward_attn=False, trans_agent=False):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm, prenet_type, forward_attn, trans_agent)
         self.postnet = Postnet(self.n_mel_channels)
 
     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):
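A hypothetical smoke test for the flag threading shown above (not part of this commit; it assumes models/tacotron2.py is importable from the project root): the transition-agent layer should exist only when trans_agent=True.

from models.tacotron2 import Tacotron2

# num_chars and r are placeholders; only the constructor is exercised.
model = Tacotron2(num_chars=61, r=1, forward_attn=True, trans_agent=True)
assert hasattr(model.decoder.attention_layer, "ta")

model_no_ta = Tacotron2(num_chars=61, r=1, forward_attn=True, trans_agent=False)
assert not hasattr(model_no_ta.decoder.attention_layer, "ta")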