Use location-sensitive attention (hybrid)

pull/2/head
Keith Ito 2018-03-17 13:48:45 -07:00
parent 14960104e5
commit 1ff4995eb4
2 changed files with 62 additions and 2 deletions

59
models/attention.py Normal file
View File

@@ -0,0 +1,59 @@
import tensorflow as tf
from tensorflow.contrib.seq2seq import BahdanauAttention
class LocationSensitiveAttention(BahdanauAttention):
  '''Implements Location Sensitive Attention from:
  Chorowski, Jan et al. 'Attention-Based Models for Speech Recognition'
  https://arxiv.org/abs/1506.07503

  Extends BahdanauAttention by convolving the previous alignments and adding
  the projected result as an extra term in the score function, so the
  mechanism can take its previous focus location into account.
  '''

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               filters=20,
               kernel_size=7,
               name='LocationSensitiveAttention'):
    '''Construct the Attention mechanism. See superclass for argument details.

    Args:
      num_units: Depth of the query/key projection layers.
      memory: Encoder outputs to attend over (the memory), shape [N, T_in, ...].
      memory_sequence_length: Optional lengths used by the superclass to mask
        the memory.
      filters: Number of filters in the 1-D location convolution.
      kernel_size: Kernel width of the 1-D location convolution.
      name: Name to use when creating ops and the default scope name.
    '''
    super(LocationSensitiveAttention, self).__init__(
        num_units,
        memory,
        memory_sequence_length=memory_sequence_length,
        name=name)
    # 1-D convolution over the previous alignments; 'same' padding preserves
    # the T_in dimension.
    self.location_conv = tf.layers.Conv1D(
        filters, kernel_size, padding='same', use_bias=False, name='location_conv')
    # Projects the convolved features to num_units so they can be summed with
    # the processed query and keys inside the score function.
    self.location_layer = tf.layers.Dense(
        num_units, use_bias=False, dtype=tf.float32, name='location_layer')

  def __call__(self, query, state):
    '''Score the query based on the keys and values.

    This replaces the superclass implementation in order to add in the
    location term.

    Args:
      query: Tensor of shape `[N, num_units]`.
      state: Tensor of shape `[N, T_in]` -- presumably the alignments from the
        previous decoder step (BahdanauAttention's state convention).

    Returns:
      alignments: Tensor of shape `[N, T_in]`
      next_state: Tensor of shape `[N, T_in]`
    '''
    with tf.variable_scope(None, 'location_sensitive_attention', [query]):
      expanded_alignments = tf.expand_dims(state, axis=2)      # [N, T_in, 1]
      f = self.location_conv(expanded_alignments)              # [N, T_in, filters]
      processed_location = self.location_layer(f)              # [N, T_in, num_units]
      # query_layer may be None when the query depth already equals num_units.
      processed_query = self.query_layer(query) if self.query_layer else query  # [N, num_units]
      processed_query = tf.expand_dims(processed_query, axis=1)  # [N, 1, num_units]
      score = _location_sensitive_score(processed_query, processed_location, self.keys)
    alignments = self._probability_fn(score, state)
    # Next state is simply this step's alignments (no accumulation).
    next_state = alignments
    return alignments, next_state
def _location_sensitive_score(processed_query, processed_location, keys):
  '''Location-sensitive attention score function.

  Based on _bahdanau_score from tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py,
  with the location term added inside the tanh.

  Args:
    processed_query: Projected query, shape [N, 1, num_units].
    processed_location: Projected location features, shape [N, T_in, num_units].
    keys: Processed memory, shape [N, T_in, num_units].

  Returns:
    Tensor of shape [N, T_in] containing the unnormalized attention scores.
  '''
  # Get the number of hidden units from the trailing dimension of keys.
  # Fall back to the dynamic shape when it is not statically known.
  # (Fixed: the original referenced array_ops.shape, but `array_ops` is never
  # imported in this module -- use tf.shape, which is the same op.)
  num_units = keys.shape[2].value or tf.shape(keys)[2]
  v = tf.get_variable('attention_v', [num_units], dtype=processed_query.dtype)
  # Broadcast-add query [N, 1, units] and location [N, T_in, units] onto keys,
  # then reduce over the units dimension to get one score per memory position.
  return tf.reduce_sum(v * tf.tanh(keys + processed_query + processed_location), [2])

View File

@@ -1,8 +1,9 @@
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
from tensorflow.contrib.seq2seq import BasicDecoder, AttentionWrapper
from text.symbols import symbols
from util.infolog import log
from .attention import LocationSensitiveAttention
from .helpers import TacoTestHelper, TacoTrainingHelper
from .modules import encoder_cbhg, post_cbhg, prenet
from .rnn_wrappers import DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper
@@ -49,7 +50,7 @@ class Tacotron():
# Attention
attention_cell = AttentionWrapper(
DecoderPrenetWrapper(GRUCell(256), is_training),
BahdanauAttention(256, encoder_outputs),
LocationSensitiveAttention(256, encoder_outputs),
alignment_history=True,
output_attention=False) # [N, T_in, 256]