Set attention norm method by config.json

pull/10/head
Eren Golge 2019-03-26 00:48:12 +01:00
parent 786510cd6a
commit 0a92c6d5a7
8 changed files with 31 additions and 16 deletions

@@ -39,6 +39,7 @@
     "warmup_steps": 4000,  // Noam decay steps to increase the learning rate from 0 to "lr"
     "windowing": false,  // Enables attention windowing. Used only in eval mode.
     "memory_size": 5,  // TO BE IMPLEMENTED -- memory queue size used to queue network predictions to feed autoregressive connection. Useful if r < 5.
+    "attention_norm": "softmax",  // softmax or sigmoid. Suggested to use softmax for Tacotron2 and sigmoid for Tacotron.
     "batch_size": 16,  // Batch size for training. Values lower than 32 might make attention hard to learn.
     "eval_batch_size":16,

@@ -100,7 +100,7 @@ class LocationSensitiveAttention(nn.Module):
 class AttentionRNNCell(nn.Module):
-    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False):
+    def __init__(self, out_dim, rnn_dim, annot_dim, memory_dim, align_model, windowing=False, norm="sigmoid"):
         r"""
         General Attention RNN wrapper
@@ -112,6 +112,7 @@ class AttentionRNNCell(nn.Module):
             align_model (str): 'b' for Bahdanau, 'ls' Location Sensitive alignment.
             windowing (bool): attention windowing forcing monotonic attention.
                 It is only active in eval mode.
+            norm (str): norm method to compute alignment weights.
         """
         super(AttentionRNNCell, self).__init__()
         self.align_model = align_model
@@ -121,7 +122,7 @@ class AttentionRNNCell(nn.Module):
         self.win_back = 3
         self.win_front = 6
         self.win_idx = None
-        # pick bahdanau or location sensitive attention
+        self.norm = norm
         if align_model == 'b':
             self.alignment_model = BahdanauAttention(annot_dim, rnn_dim,
                                                      out_dim)
@@ -164,7 +165,12 @@ class AttentionRNNCell(nn.Module):
                 alignment[:, front_win:] = -float("inf")
             # Update the window
             self.win_idx = torch.argmax(alignment, 1).long()[0].item()
-        alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(alignment, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(alignment) / torch.sigmoid(alignment).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), annots)
         context = context.squeeze(1)
         return rnn_output, context, alignment
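
To make the two branches above concrete, a standalone sketch (not repo code) contrasting softmax normalization with the sigmoid-based normalization on dummy alignment scores, followed by the same torch.bmm context pooling used in the forward pass:

import torch

torch.manual_seed(0)
batch, enc_steps, annot_dim = 2, 6, 4
scores = torch.randn(batch, enc_steps)             # raw alignment energies
annots = torch.randn(batch, enc_steps, annot_dim)  # encoder annotations

# norm == "softmax": mutually competing weights via the softmax.
w_softmax = torch.softmax(scores, dim=-1)

# norm == "sigmoid": independent sigmoids, renormalized so each row sums to 1.
sig = torch.sigmoid(scores)
w_sigmoid = sig / sig.sum(dim=1, keepdim=True)

# Both yield one weight per encoder step, summing to 1 (up to float error).
print(w_softmax.sum(dim=1), w_sigmoid.sum(dim=1))

# Context vector: alignment-weighted sum of annotations, as in the code above.
context = torch.bmm(w_softmax.unsqueeze(1), annots).squeeze(1)
print(context.shape)  # torch.Size([2, 4])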

@@ -302,7 +302,7 @@ class Decoder(nn.Module):
     """
     def __init__(self, in_features, memory_dim, r, memory_size,
-                 attn_windowing):
+                 attn_windowing, attn_norm):
         super(Decoder, self).__init__()
         self.r = r
         self.in_features = in_features
@@ -319,7 +319,8 @@
             annot_dim=in_features,
             memory_dim=128,
             align_model='ls',
-            windowing=attn_windowing)
+            windowing=attn_windowing,
+            norm=attn_norm)
         # (processed_memory | attention context) -> |Linear| -> decoder_RNN_input
         self.project_to_decoder_in = nn.Linear(256 + in_features, 256)
         # decoder_RNN_input -> |RNN| -> RNN_state

@@ -112,7 +112,7 @@ class LocationLayer(nn.Module):
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                  attention_location_n_filters, attention_location_kernel_size,
-                 windowing):
+                 windowing, norm):
         super(Attention, self).__init__()
         self.query_layer = Linear(
             attention_rnn_dim, attention_dim, bias=False, init_gain='tanh')
@@ -128,6 +128,7 @@ class Attention(nn.Module):
         self.win_back = 1
         self.win_front = 3
         self.win_idx = None
+        self.norm = norm

     def init_win_idx(self):
         self.win_idx = -1
@@ -163,8 +164,13 @@ class Attention(nn.Module):
                 attention[:, 0] = attention.max()
             # Update the window
             self.win_idx = torch.argmax(attention, 1).long()[0].item()
-        alignment = torch.sigmoid(attention) / torch.sigmoid(
-            attention).sum(dim=1).unsqueeze(1)
+        if self.norm == "softmax":
+            alignment = torch.softmax(attention, dim=-1)
+        elif self.norm == "sigmoid":
+            alignment = torch.sigmoid(attention) / torch.sigmoid(
+                attention).sum(dim=1).unsqueeze(1)
+        else:
+            raise RuntimeError("Unknown value for attention norm type")
         context = torch.bmm(alignment.unsqueeze(1), inputs)
         context = context.squeeze(1)
         return context, alignment
@@ -237,7 +243,7 @@ class Encoder(nn.Module):
 # adapted from https://github.com/NVIDIA/tacotron2/
 class Decoder(nn.Module):
-    def __init__(self, in_features, inputs_dim, r, attn_win):
+    def __init__(self, in_features, inputs_dim, r, attn_win, attn_norm):
         super(Decoder, self).__init__()
         self.mel_channels = inputs_dim
         self.r = r
@@ -257,7 +263,7 @@
                                          self.attention_rnn_dim)
         self.attention_layer = Attention(self.attention_rnn_dim, in_features,
-                                         128, 32, 31, attn_win)
+                                         128, 32, 31, attn_win, attn_norm)
         self.decoder_rnn = nn.LSTMCell(self.attention_rnn_dim + in_features,
                                        self.decoder_rnn_dim, 1)
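
The windowing branch above masks scores outside the active window with -float("inf") before normalization. A small sketch (again not repo code; the window state is made up) showing that both norm options then give the masked positions zero weight, since sigmoid(-inf) is 0 and softmax assigns no mass to -inf entries:

import torch

scores = torch.tensor([[0.2, 1.5, 0.3, -0.7, 0.1]])
win_idx, win_back, win_front = 1, 1, 3     # hypothetical window state
back_win = win_idx - win_back
front_win = win_idx + win_front

masked = scores.clone()
if back_win > 0:
    masked[:, :back_win] = -float("inf")
if front_win < masked.shape[1]:
    masked[:, front_win:] = -float("inf")

w_softmax = torch.softmax(masked, dim=-1)
sig = torch.sigmoid(masked)                # sigmoid(-inf) == 0
w_sigmoid = sig / sig.sum(dim=1, keepdim=True)

print(w_softmax)   # zero weight at the masked positions
print(w_sigmoid)   # likewise zero outside the window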

@@ -14,7 +14,8 @@ class Tacotron(nn.Module):
                  r=5,
                  padding_idx=None,
                  memory_size=5,
-                 attn_windowing=False):
+                 attn_windowing=False,
+                 attn_norm="sigmoid"):
         super(Tacotron, self).__init__()
         self.r = r
         self.mel_dim = mel_dim
@@ -22,7 +23,7 @@ class Tacotron(nn.Module):
         self.embedding = nn.Embedding(num_chars, 256, padding_idx=padding_idx)
         self.embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
-        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing)
+        self.decoder = Decoder(256, mel_dim, r, memory_size, attn_windowing, attn_norm)
         self.postnet = PostCBHG(mel_dim)
         self.last_linear = nn.Sequential(
             nn.Linear(self.postnet.cbhg.gru_features * 2, linear_dim),

@@ -9,7 +9,7 @@ from utils.generic_utils import sequence_mask
 # TODO: match function arguments with tacotron
 class Tacotron2(nn.Module):
-    def __init__(self, num_chars, r, attn_win=False):
+    def __init__(self, num_chars, r, attn_win=False, attn_norm="softmax"):
         super(Tacotron2, self).__init__()
         self.n_mel_channels = 80
         self.n_frames_per_step = r
@@ -18,7 +18,7 @@ class Tacotron2(nn.Module):
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
         self.encoder = Encoder(512)
-        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win)
+        self.decoder = Decoder(512, self.n_mel_channels, r, attn_win, attn_norm)
         self.postnet = Postnet(self.n_mel_channels)

     def shape_outputs(self, mel_outputs, mel_outputs_postnet, alignments):

@@ -38,7 +38,7 @@ class CBHGTests(unittest.TestCase):
 class DecoderTests(unittest.TestCase):
     def test_in_out(self):
-        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False)
+        layer = Decoder(in_features=256, memory_dim=80, r=2, memory_size=4, attn_windowing=False, attn_norm="sigmoid")
         dummy_input = T.rand(4, 8, 256)
         dummy_memory = T.rand(4, 2, 80)

@@ -375,7 +375,7 @@ def main(args):
         init_distributed(args.rank, num_gpus, args.group_id,
                          c.distributed["backend"], c.distributed["url"])
     num_chars = len(phonemes) if c.use_phonemes else len(symbols)
-    model = MyModel(num_chars=num_chars, r=c.r)
+    model = MyModel(num_chars=num_chars, r=c.r, attn_norm=c.attention_norm)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
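
End to end, the value read from config.json is simply forwarded into the model constructor as attn_norm. A tiny stand-in sketch of that plumbing (MyModel here is a dummy class, not the repo's Tacotron/Tacotron2, and the config fragment is inlined):

import json

class MyModel:
    # Dummy stand-in: the real Tacotron/Tacotron2 pass attn_norm down to their
    # Decoder, which hands it to the attention module shown earlier.
    def __init__(self, num_chars, r, attn_norm="sigmoid"):
        self.attn_norm = attn_norm

c = json.loads('{"r": 5, "attention_norm": "softmax"}')
model = MyModel(num_chars=61, r=c["r"], attn_norm=c["attention_norm"])
print(model.attn_norm)  # -> softmax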