mirror of https://github.com/coqui-ai/TTS.git

add monotonic dynamic convolution attention

parent 18392bc13a
commit 070146e143
@@ -99,7 +99,7 @@
     "prenet_dropout": false,       // enable/disable dropout at prenet.

     // TACOTRON ATTENTION
-    "attention_type": "original",  // 'original' or 'graves'
+    "attention_type": "original",  // 'original', 'graves', 'dynamic_convolution'
     "attention_heads": 4,          // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",   // softmax or sigmoid.
     "windowing": false,            // Enables attention windowing. Used only in eval mode.
@@ -1,6 +1,7 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+from scipy.stats import betabinom


 class Linear(nn.Module):
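The new scipy import backs the beta-binomial alignment prior used by the attention class added below. As a standalone sketch (my own illustration, not part of the commit), here is that prior with the defaults the commit uses later (prior_filter_len=11, alpha=0.1, beta=0.9). Its mean is (prior_filter_len - 1) * alpha / (alpha + beta) = 10 * 0.1 / 1.0 = 1, so it nudges the alignment forward by about one encoder token per decoder step:

# Sketch (not in the commit): the beta-binomial alignment prior with the
# defaults that MonotonicDynamicConvolutionAttention uses below.
from scipy.stats import betabinom

prior_filter_len, alpha, beta = 11, 0.1, 0.9
n = prior_filter_len - 1
prior = betabinom.pmf(range(prior_filter_len), n, alpha, beta)
print(prior.sum())  # ~1.0: a proper distribution over forward jumps 0..10
# mean = n * alpha / (alpha + beta) = 1.0, i.e. the prior expects the
# alignment to advance roughly one encoder token per decoder step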
@@ -371,6 +372,90 @@ class OriginalAttention(nn.Module):
         self.u = torch.sigmoid(self.ta(ta_input))
         return context


+class MonotonicDynamicConvolutionAttention(nn.Module):
+    """Dynamic convolution attention from
+    https://arxiv.org/pdf/1910.10288.pdf
+    """
+    def __init__(
+            self,
+            query_dim,
+            embedding_dim,
+            attention_dim,
+            static_filter_dim,
+            static_kernel_size,
+            dynamic_filter_dim,
+            dynamic_kernel_size,
+            prior_filter_len=11,
+            alpha=0.1,
+            beta=0.9,
+    ):
+        super().__init__()
+        self._mask_value = 1e-8
+        self.dynamic_filter_dim = dynamic_filter_dim
+        self.dynamic_kernel_size = dynamic_kernel_size
+        self.prior_filter_len = prior_filter_len
+        self.attention_weights = None
+        # setup key and query layers
+        self.query_layer = nn.Linear(query_dim, attention_dim)
+        self.key_layer = nn.Linear(
+            attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
+        )
+        self.static_filter_conv = nn.Conv1d(
+            1,
+            static_filter_dim,
+            static_kernel_size,
+            padding=(static_kernel_size - 1) // 2,
+            bias=False,
+        )
+        self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False)
+        self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
+        self.v = nn.Linear(attention_dim, 1, bias=False)
+
+        prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
+                              alpha, beta)
+        self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
+
+    # pylint: disable=unused-argument
+    def forward(self, query, inputs, processed_inputs, mask):
+        # compute prior filters
+        prior_filter = F.conv1d(
+            F.pad(self.attention_weights.unsqueeze(1),
+                  (self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
+        prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
+        G = self.key_layer(torch.tanh(self.query_layer(query)))
+        # compute dynamic filters
+        dynamic_filter = F.conv1d(
+            self.attention_weights.unsqueeze(0),
+            G.view(-1, 1, self.dynamic_kernel_size),
+            padding=(self.dynamic_kernel_size - 1) // 2,
+            groups=query.size(0),
+        )
+        dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
+        # compute static filters
+        static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
+        alignment = self.v(
+            torch.tanh(
+                self.static_filter_layer(static_filter) +
+                self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
+        # compute attention weights
+        attention_weights = F.softmax(alignment, dim=-1)
+        # apply masking
+        if mask is not None:
+            attention_weights.data.masked_fill_(~mask, self._mask_value)
+        self.attention_weights = attention_weights
+        # compute context
+        context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1)
+        return context
+
+    def preprocess_inputs(self, inputs):
+        return None
+
+    def init_states(self, inputs):
+        B = inputs.size(0)
+        T = inputs.size(1)
+        self.attention_weights = torch.zeros([B, T], device=inputs.device)
+        self.attention_weights[:, 0] = 1.
+
+
 def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
               location_attention, attention_location_n_filters,
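A detail worth noting in the forward pass above: the query-dependent kernels G are applied with a single grouped conv1d by folding the batch dimension into the channel dimension, so each sample is convolved only with its own predicted filters. A minimal sketch of that trick with made-up sizes (my own illustration, not part of the commit):

import torch
import torch.nn.functional as F

B, T, n_filters, k = 2, 50, 8, 21
attention_weights = torch.rand(B, T)        # previous alignment
G = torch.rand(B, n_filters * k)            # per-sample kernels, as from key_layer
out = F.conv1d(
    attention_weights.unsqueeze(0),         # [1, B, T]: batch folded into channels
    G.view(B * n_filters, 1, k),            # B*n_filters kernels, 1 input channel each
    padding=(k - 1) // 2,
    groups=B,                               # each sample sees only its own kernels
)
out = out.view(B, n_filters, T).transpose(1, 2)
print(out.shape)                            # torch.Size([2, 50, 8])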
@@ -385,5 +470,17 @@ def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
                   forward_attn_mask)
     if attn_type == "graves":
         return GravesAttention(query_dim, attn_K)
+    if attn_type == "dynamic_convolution":
+        return MonotonicDynamicConvolutionAttention(query_dim,
+                                                    embedding_dim,
+                                                    attention_dim,
+                                                    static_filter_dim=8,
+                                                    static_kernel_size=21,
+                                                    dynamic_filter_dim=8,
+                                                    dynamic_kernel_size=21,
+                                                    prior_filter_len=11,
+                                                    alpha=0.1,
+                                                    beta=0.9)
+
     raise RuntimeError(
         f" [!] Given Attention Type '{attn_type}' does not exist.")
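To sanity-check the new branch, here is a hypothetical smoke test that constructs the module directly with the same defaults the dispatch above uses. The import path is an assumption; adjust it to wherever attentions.py lives in your checkout:

# Hypothetical smoke test (not part of the commit); import path is an assumption.
import torch
from TTS.tts.layers.attentions import MonotonicDynamicConvolutionAttention

B, T, query_dim, embedding_dim = 2, 50, 256, 512
attn = MonotonicDynamicConvolutionAttention(
    query_dim, embedding_dim, attention_dim=128,
    static_filter_dim=8, static_kernel_size=21,
    dynamic_filter_dim=8, dynamic_kernel_size=21,
    prior_filter_len=11, alpha=0.1, beta=0.9)
inputs = torch.rand(B, T, embedding_dim)    # encoder outputs
attn.init_states(inputs)                    # alignment starts on the first token
context = attn(torch.rand(B, query_dim), inputs, None, mask=None)
print(context.shape)                        # torch.Size([2, 512])
print(attn.attention_weights.shape)         # torch.Size([2, 50])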
@@ -211,7 +211,7 @@ def check_config_tts(c):
     check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)

     # attention
-    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
+    check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
     check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
     check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
     check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)
@@ -100,7 +100,7 @@
     "prenet_dropout": false,       // enable/disable dropout at prenet.

     // TACOTRON ATTENTION
-    "attention_type": "original",  // 'original' or 'graves'
+    "attention_type": "original",  // 'original', 'graves', 'dynamic_convolution'
     "attention_heads": 4,          // number of attention heads (only for 'graves')
     "attention_norm": "sigmoid",   // softmax or sigmoid.
     "windowing": false,            // Enables attention windowing. Used only in eval mode.