add monotonic dynamic convolution attention

pull/10/head
erogol 2020-12-12 20:15:31 +01:00
parent 18392bc13a
commit 070146e143
4 changed files with 101 additions and 4 deletions

View File

@ -99,7 +99,7 @@
"prenet_dropout": false, // enable/disable dropout at prenet.
// TACOTRON ATTENTION
"attention_type": "original", // 'original' or 'graves'
"attention_type": "original", // 'original' , 'graves', 'dynamic_convolution'
"attention_heads": 4, // number of attention heads (only for 'graves')
"attention_norm": "sigmoid", // softmax or sigmoid.
"windowing": false, // Enables attention windowing. Used only in eval mode.

View File

@ -1,6 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
from scipy.stats import betabinom
class Linear(nn.Module):
@ -371,6 +372,90 @@ class OriginalAttention(nn.Module):
self.u = torch.sigmoid(self.ta(ta_input))
return context
class MonotonicDynamicConvolutionAttention(nn.Module):
"""Dynamic convolution attention from
https://arxiv.org/pdf/1910.10288.pdf
"""
def __init__(
self,
query_dim,
embedding_dim,
attention_dim,
static_filter_dim,
static_kernel_size,
dynamic_filter_dim,
dynamic_kernel_size,
prior_filter_len=11,
alpha=0.1,
beta=0.9,
):
super().__init__()
self._mask_value = 1e-8
self.dynamic_filter_dim = dynamic_filter_dim
self.dynamic_kernel_size = dynamic_kernel_size
self.prior_filter_len = prior_filter_len
self.attention_weights = None
# setup key and query layers
self.query_layer = nn.Linear(query_dim, attention_dim)
self.key_layer = nn.Linear(
attention_dim, dynamic_filter_dim * dynamic_kernel_size, bias=False
)
self.static_filter_conv = nn.Conv1d(
1,
static_filter_dim,
static_kernel_size,
padding=(static_kernel_size - 1) // 2,
bias=False,
)
self.static_filter_layer = nn.Linear(static_filter_dim, attention_dim, bias=False)
self.dynamic_filter_layer = nn.Linear(dynamic_filter_dim, attention_dim)
self.v = nn.Linear(attention_dim, 1, bias=False)
prior = betabinom.pmf(range(prior_filter_len), prior_filter_len - 1,
alpha, beta)
self.register_buffer("prior", torch.FloatTensor(prior).flip(0))
# pylint: disable=unused-argument
def forward(self, query, inputs, processed_inputs, mask):
# compute prior filters
prior_filter = F.conv1d(
F.pad(self.attention_weights.unsqueeze(1),
(self.prior_filter_len - 1, 0)), self.prior.view(1, 1, -1))
prior_filter = torch.log(prior_filter.clamp_min_(1e-6)).squeeze(1)
G = self.key_layer(torch.tanh(self.query_layer(query)))
# compute dynamic filters
dynamic_filter = F.conv1d(
self.attention_weights.unsqueeze(0),
G.view(-1, 1, self.dynamic_kernel_size),
padding=(self.dynamic_kernel_size - 1) // 2,
groups=query.size(0),
)
dynamic_filter = dynamic_filter.view(query.size(0), self.dynamic_filter_dim, -1).transpose(1, 2)
# compute static filters
static_filter = self.static_filter_conv(self.attention_weights.unsqueeze(1)).transpose(1, 2)
alignment = self.v(
torch.tanh(
self.static_filter_layer(static_filter) +
self.dynamic_filter_layer(dynamic_filter))).squeeze(-1) + prior_filter
# compute attention weights
attention_weights = F.softmax(alignment, dim=-1)
# apply masking
if mask is not None:
attention_weights.data.masked_fill_(~mask, self._mask_value)
self.attention_weights = attention_weights
# compute context
context = torch.bmm(attention_weights.unsqueeze(1), inputs).squeeze(1)
return context
def preprocess_inputs(self, inputs):
return None
def init_states(self, inputs):
B = inputs.size(0)
T = inputs.size(1)
self.attention_weights = torch.zeros([B, T], device=inputs.device)
self.attention_weights[:, 0] = 1.
def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
location_attention, attention_location_n_filters,
@ -385,5 +470,17 @@ def init_attn(attn_type, query_dim, embedding_dim, attention_dim,
forward_attn_mask)
if attn_type == "graves":
return GravesAttention(query_dim, attn_K)
if attn_type == "dynamic_convolution":
return MonotonicDynamicConvolutionAttention(query_dim,
embedding_dim,
attention_dim,
static_filter_dim=8,
static_kernel_size=21,
dynamic_filter_dim=8,
dynamic_kernel_size=21,
prior_filter_len=11,
alpha=0.1,
beta=0.9)
raise RuntimeError(
" [!] Given Attention Type '{attn_type}' is not exist.")
" [!] Given Attention Type '{attn_type}' is not exist.")

View File

@ -211,7 +211,7 @@ def check_config_tts(c):
check_argument('prenet_dropout', c, restricted=is_tacotron(c), val_type=bool)
# attention
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original'])
check_argument('attention_type', c, restricted=is_tacotron(c), val_type=str, enum_list=['graves', 'original', 'dynamic_convolution'])
check_argument('attention_heads', c, restricted=is_tacotron(c), val_type=int)
check_argument('attention_norm', c, restricted=is_tacotron(c), val_type=str, enum_list=['sigmoid', 'softmax'])
check_argument('windowing', c, restricted=is_tacotron(c), val_type=bool)

View File

@ -100,7 +100,7 @@
"prenet_dropout": false, // enable/disable dropout at prenet.
// TACOTRON ATTENTION
"attention_type": "original", // 'original' or 'graves'
"attention_type": "original", // 'original' , 'graves', 'dynamic_convolution'
"attention_heads": 4, // number of attention heads (only for 'graves')
"attention_norm": "sigmoid", // softmax or sigmoid.
"windowing": false, // Enables attention windowing. Used only in eval mode.