From 1ff4995eb4fd7e4cec3725d6f720c3a9ddbf0635 Mon Sep 17 00:00:00 2001
From: Keith Ito
Date: Sat, 17 Mar 2018 13:48:45 -0700
Subject: [PATCH] Use location-sensitive attention (hybrid)

---
 models/attention.py | 59 +++++++++++++++++++++++++++++++++++++++++++++
 models/tacotron.py  |  5 ++--
 2 files changed, 62 insertions(+), 2 deletions(-)
 create mode 100644 models/attention.py

diff --git a/models/attention.py b/models/attention.py
new file mode 100644
index 0000000..fb8a059
--- /dev/null
+++ b/models/attention.py
@@ -0,0 +1,59 @@
+import tensorflow as tf
+from tensorflow.contrib.seq2seq import BahdanauAttention
+
+
+class LocationSensitiveAttention(BahdanauAttention):
+  '''Implements Location Sensitive Attention from:
+  Chorowski, Jan et al. 'Attention-Based Models for Speech Recognition'
+  https://arxiv.org/abs/1506.07503
+  '''
+  def __init__(self,
+               num_units,
+               memory,
+               memory_sequence_length=None,
+               filters=20,
+               kernel_size=7,
+               name='LocationSensitiveAttention'):
+    '''Construct the Attention mechanism. See superclass for argument details.'''
+    super(LocationSensitiveAttention, self).__init__(
+      num_units,
+      memory,
+      memory_sequence_length=memory_sequence_length,
+      name=name)
+    self.location_conv = tf.layers.Conv1D(
+      filters, kernel_size, padding='same', use_bias=False, name='location_conv')
+    self.location_layer = tf.layers.Dense(
+      num_units, use_bias=False, dtype=tf.float32, name='location_layer')
+
+
+  def __call__(self, query, state):
+    '''Score the query based on the keys and values.
+    This replaces the superclass implementation in order to add in the location term.
+    Args:
+      query: Tensor of shape `[N, num_units]`.
+      state: Tensor of shape `[N, T_in]`
+    Returns:
+      alignments: Tensor of shape `[N, T_in]`
+      next_state: Tensor of shape `[N, T_in]`
+    '''
+    with tf.variable_scope(None, 'location_sensitive_attention', [query]):
+      expanded_alignments = tf.expand_dims(state, axis=2)    # [N, T_in, 1]
+      f = self.location_conv(expanded_alignments)            # [N, T_in, filters]
+      processed_location = self.location_layer(f)            # [N, T_in, num_units]
+
+      processed_query = self.query_layer(query) if self.query_layer else query  # [N, num_units]
+      processed_query = tf.expand_dims(processed_query, axis=1)                 # [N, 1, num_units]
+      score = _location_sensitive_score(processed_query, processed_location, self.keys)
+    alignments = self._probability_fn(score, state)
+    next_state = alignments
+    return alignments, next_state
+
+
+def _location_sensitive_score(processed_query, processed_location, keys):
+  '''Location-sensitive attention score function.
+  Based on _bahdanau_score from tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
+  '''
+  # Get the number of hidden units from the trailing dimension of keys
+  num_units = keys.shape[2].value or tf.shape(keys)[2]
+  v = tf.get_variable('attention_v', [num_units], dtype=processed_query.dtype)
+  return tf.reduce_sum(v * tf.tanh(keys + processed_query + processed_location), [2])
diff --git a/models/tacotron.py b/models/tacotron.py
index 5fce5f7..46b3376 100644
--- a/models/tacotron.py
+++ b/models/tacotron.py
@@ -1,8 +1,9 @@
 import tensorflow as tf
 from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
-from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
+from tensorflow.contrib.seq2seq import BasicDecoder, AttentionWrapper
 from text.symbols import symbols
 from util.infolog import log
+from .attention import LocationSensitiveAttention
 from .helpers import TacoTestHelper, TacoTrainingHelper
 from .modules import encoder_cbhg, post_cbhg, prenet
 from .rnn_wrappers import DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper
@@ -49,7 +50,7 @@ class Tacotron():
       # Attention
       attention_cell = AttentionWrapper(
         DecoderPrenetWrapper(GRUCell(256), is_training),
-        BahdanauAttention(256, encoder_outputs),
+        LocationSensitiveAttention(256, encoder_outputs),
         alignment_history=True,
         output_attention=False)                             # [N, T_in, 256]
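
For reference, a minimal smoke-test sketch of the new mechanism, not part of the
patch itself. It assumes TensorFlow 1.x with tf.contrib.seq2seq available; the
batch size, encoder length, and dimensions below are arbitrary illustrative values.

  import tensorflow as tf
  from models.attention import LocationSensitiveAttention

  batch_size, enc_len, enc_dim = 2, 30, 256          # arbitrary example sizes
  encoder_outputs = tf.random_normal([batch_size, enc_len, enc_dim])
  attention = LocationSensitiveAttention(256, encoder_outputs)

  # One decoder query vector per batch item; initial alignments start as all zeros.
  query = tf.random_normal([batch_size, 256])
  state = attention.initial_alignments(batch_size, dtype=tf.float32)  # [N, T_in]
  alignments, next_state = attention(query, state)                    # both [N, T_in]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(alignments).shape)  # expected: (2, 30)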