import tensorflow as tf
from tensorflow.contrib.seq2seq import BahdanauAttention
from tensorflow.python.ops import array_ops


class LocationSensitiveAttention(BahdanauAttention):
  '''Implements Location Sensitive Attention from:

  Chorowski, Jan et al. 'Attention-Based Models for Speech Recognition'
  https://arxiv.org/abs/1506.07503
  '''
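
  # How the score differs from plain Bahdanau attention (a sketch of the rule in the
  # cited paper; the symbols are illustrative, not names used in this file):
  #   energy[i, j] = v^T tanh(W s_i + V h_j + U f_ij)
  # where f_i is a convolution of the previous alignments, so each decoding step can
  # see where the mechanism attended on the step before.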

  def __init__(self,
               num_units,
               memory,
               memory_sequence_length=None,
               filters=20,
               kernel_size=7,
               name='LocationSensitiveAttention'):
    '''Construct the Attention mechanism. See superclass for argument details.'''
    super(LocationSensitiveAttention, self).__init__(
      num_units,
      memory,
      memory_sequence_length=memory_sequence_length,
      name=name)
    # Convolves the previous alignments into location features.
    self.location_conv = tf.layers.Conv1D(
      filters, kernel_size, padding='same', use_bias=False, name='location_conv')
    # Projects the location features into the same space as the keys and processed query.
    self.location_layer = tf.layers.Dense(
      num_units, use_bias=False, dtype=tf.float32, name='location_layer')

  def __call__(self, query, state):
    '''Score the query based on the keys and values.

    This replaces the superclass implementation in order to add in the location term.

    Args:
      query: Tensor of shape `[N, num_units]`.
      state: Tensor of shape `[N, T_in]`

    Returns:
      alignments: Tensor of shape `[N, T_in]`
      next_state: Tensor of shape `[N, T_in]`
    '''
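    # `state` is the previous decoder step's alignments (the AttentionWrapper passes them
    # in), which is what makes the mechanism location sensitive.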
    with tf.variable_scope(None, 'location_sensitive_attention', [query]):
      expanded_alignments = tf.expand_dims(state, axis=2)                        # [N, T_in, 1]
      f = self.location_conv(expanded_alignments)                                # [N, T_in, filters]
      processed_location = self.location_layer(f)                                # [N, T_in, num_units]

      processed_query = self.query_layer(query) if self.query_layer else query   # [N, num_units]
      processed_query = tf.expand_dims(processed_query, axis=1)                  # [N, 1, num_units]
      score = _location_sensitive_score(processed_query, processed_location, self.keys)
    alignments = self._probability_fn(score, state)
    next_state = alignments
    return alignments, next_state


def _location_sensitive_score(processed_query, processed_location, keys):
  '''Location-sensitive attention score function.

  Based on _bahdanau_score from tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py
  '''
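  # Identical to _bahdanau_score except that processed_location is added inside the tanh.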
  # Get the number of hidden units from the trailing dimension of keys
  num_units = keys.shape[2].value or array_ops.shape(keys)[2]
  # Learned vector that reduces the tanh activations to a scalar score per encoder step.
  v = tf.get_variable('attention_v', [num_units], dtype=processed_query.dtype)
  return tf.reduce_sum(v * tf.tanh(keys + processed_query + processed_location), [2])
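

if __name__ == '__main__':
  # Minimal smoke-test sketch, not part of the original module: the shapes, unit counts,
  # and the GRU decoder cell below are illustrative assumptions.
  encoder_outputs = tf.placeholder(tf.float32, [None, None, 256], name='encoder_outputs')
  attention_mechanism = LocationSensitiveAttention(128, encoder_outputs)
  attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    tf.nn.rnn_cell.GRUCell(128), attention_mechanism, alignment_history=True)
  print(attention_cell.state_size)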