# File: mimic2/models/attention.py (Python, 60 lines, 2.6 KiB)

import tensorflow as tf
from tensorflow.contrib.seq2seq import BahdanauAttention
class LocationSensitiveAttention(BahdanauAttention):
    '''Implements Location Sensitive Attention from:
    Chorowski, Jan et al. 'Attention-Based Models for Speech Recognition'
    https://arxiv.org/abs/1506.07503

    Extends BahdanauAttention with a location term: the previous alignments
    are passed through a 1-D convolution (``location_conv``) and projected to
    ``num_units`` (``location_layer``), then added into the additive score
    alongside the processed query and the keys.
    '''

    def __init__(self,
                 num_units,
                 memory,
                 memory_sequence_length=None,
                 filters=20,
                 kernel_size=7,
                 name='LocationSensitiveAttention'):
        '''Construct the Attention mechanism. See superclass for argument details.

        Args:
          num_units: Attention depth; forwarded to the superclass and used as
            the output size of ``location_layer``.
          memory: Memory to attend over (see superclass).
          memory_sequence_length: Optional memory lengths (see superclass).
          filters: Number of output channels of the location convolution.
          kernel_size: Width of the location convolution kernel.
          name: Name forwarded to the superclass.
        '''
        super(LocationSensitiveAttention, self).__init__(
            num_units,
            memory,
            memory_sequence_length=memory_sequence_length,
            name=name)
        # padding='same' keeps the time axis at length T_in so the location
        # features align position-by-position with the keys.
        self.location_conv = tf.layers.Conv1D(
            filters, kernel_size, padding='same', use_bias=False, name='location_conv')
        self.location_layer = tf.layers.Dense(
            num_units, use_bias=False, dtype=tf.float32, name='location_layer')

    def __call__(self, query, state):
        '''Score the query based on the keys and values.

        This replaces the superclass implementation in order to add in the
        location term.

        Args:
          query: Tensor of shape `[N, query_depth]`. It is projected to
            `num_units` by `self.query_layer` when that layer exists;
            otherwise it is used as-is and must already be `num_units` wide.
          state: Previous alignments, Tensor of shape `[N, T_in]`.

        Returns:
          alignments: Tensor of shape `[N, T_in]`.
          next_state: Tensor of shape `[N, T_in]` (the same tensor as
            `alignments`).
        '''
        with tf.variable_scope(None, 'location_sensitive_attention', [query]):
            expanded_alignments = tf.expand_dims(state, axis=2)  # [N, T_in, 1]
            f = self.location_conv(expanded_alignments)  # [N, T_in, filters]
            processed_location = self.location_layer(f)  # [N, T_in, num_units]
            if self.query_layer is not None:
                processed_query = self.query_layer(query)  # [N, num_units]
            else:
                processed_query = query  # assumed already [N, num_units]
            # Add a time axis so the query broadcasts over all T_in positions.
            processed_query = tf.expand_dims(processed_query, axis=1)  # [N, 1, num_units]
            score = _location_sensitive_score(processed_query, processed_location, self.keys)
            alignments = self._probability_fn(score, state)
            # The alignments themselves are carried forward as the next state
            # so the following step can compute its location features.
            next_state = alignments
        return alignments, next_state
def _location_sensitive_score(processed_query, processed_location, keys):
    '''Location-sensitive attention score function.

    Based on _bahdanau_score from
    tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py.

    Computes, per memory position,
        score = sum over axis 2 of v * tanh(keys + processed_query + processed_location)
    where ``v`` is a learned vector (variable 'attention_v') of length num_units.

    Args:
      processed_query: Projected query; the caller in this module passes
        shape [N, 1, num_units] so it broadcasts over the time axis.
      processed_location: Projected location features; the caller passes
        [N, T_in, num_units].
      keys: Processed memory whose trailing (axis 2) dimension is num_units;
        presumably [N, T_in, num_units] -- confirm against the superclass.

    Returns:
      Unnormalized scores with axis 2 reduced away ([N, T_in] for the shapes
      above).
    '''
    # Get the number of hidden units from the trailing dimension of keys,
    # falling back to the dynamic shape when the static one is unknown.
    # `array_ops` is not imported in this module, so use the public tf.shape.
    # NOTE(review): tf.get_variable needs a fully-defined shape, so the
    # dynamic fallback is unlikely to succeed in practice.
    num_units = keys.shape[2].value or tf.shape(keys)[2]
    v = tf.get_variable('attention_v', [num_units], dtype=processed_query.dtype)
    return tf.reduce_sum(v * tf.tanh(keys + processed_query + processed_location), [2])