import tensorflow as tf
from tensorflow import keras
from tensorflow.python.ops import math_ops
# from tensorflow_addons.seq2seq import BahdanauAttention


class Linear(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(Linear, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        self.activation = keras.layers.ReLU()

    def call(self, x):
        """
        shapes:
            x: B x T x C
        """
        return self.activation(self.linear_layer(x))
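

# Usage sketch (illustrative comment, not part of the original module): `Linear`
# projects only the channel axis and keeps the batch/time dimensions, e.g.
#   layer = Linear(units=128, use_bias=True)
#   y = layer(tf.random.normal([8, 50, 80]))   # (8, 50, 80) -> (8, 50, 128)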


class LinearBN(keras.layers.Layer):
    def __init__(self, units, use_bias, **kwargs):
        super(LinearBN, self).__init__(**kwargs)
        self.linear_layer = keras.layers.Dense(units, use_bias=use_bias, name='linear_layer')
        self.batch_normalization = keras.layers.BatchNormalization(axis=-1, momentum=0.90, epsilon=1e-5, name='batch_normalization')
        self.activation = keras.layers.ReLU()

    def call(self, x, training=None):
        """
        shapes:
            x: B x T x C
        """
        out = self.linear_layer(x)
        out = self.batch_normalization(out, training=training)
        return self.activation(out)
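

# Usage sketch (illustrative comment): same interface as `Linear`, but batch-norm
# statistics are updated only when the call receives `training=True`, e.g.
#   layer = LinearBN(units=128, use_bias=True)
#   y = layer(tf.random.normal([8, 50, 80]), training=True)   # (8, 50, 128)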


class Prenet(keras.layers.Layer):
    def __init__(self,
                 prenet_type,
                 prenet_dropout,
                 units,
                 bias,
                 **kwargs):
        super(Prenet, self).__init__(**kwargs)
        self.prenet_type = prenet_type
        self.prenet_dropout = prenet_dropout
        self.linear_layers = []
        if prenet_type == "bn":
            self.linear_layers += [LinearBN(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
        elif prenet_type == "original":
            self.linear_layers += [Linear(unit, use_bias=bias, name=f'linear_layer_{idx}') for idx, unit in enumerate(units)]
        else:
            raise RuntimeError(' [!] Unknown prenet type.')
        if prenet_dropout:
            self.dropout = keras.layers.Dropout(rate=0.5)

    def call(self, x, training=None):
        """
        shapes:
            x: B x T x C
        """
        for linear in self.linear_layers:
            if self.prenet_dropout:
                x = self.dropout(linear(x), training=training)
            else:
                x = linear(x)
        return x
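

# Usage sketch (illustrative comment): a two-layer prenet as commonly used in
# Tacotron-style decoders; the layer sizes below are an assumption for the
# example, not values taken from this module.
#   prenet = Prenet('original', prenet_dropout=True, units=[256, 256], bias=False)
#   y = prenet(tf.random.normal([8, 50, 80]), training=True)   # (8, 50, 256)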


def _sigmoid_norm(score):
    attn_weights = tf.nn.sigmoid(score)
    attn_weights = attn_weights / tf.reduce_sum(attn_weights, axis=1, keepdims=True)
    return attn_weights
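

# Usage sketch (illustrative comment): `_sigmoid_norm` squashes raw scores with a
# sigmoid and renormalizes over the time axis so every row sums to 1, e.g.
#   weights = _sigmoid_norm(tf.random.normal([2, 7]))   # shape (2, 7), rows sum to 1.0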


class Attention(keras.layers.Layer):
    """TODO: implement forward_attention
    TODO: location sensitive attention
    TODO: implement attention windowing """
    def __init__(self, attn_dim, use_loc_attn, loc_attn_n_filters,
                 loc_attn_kernel_size, use_windowing, norm, use_forward_attn,
                 use_trans_agent, use_forward_attn_mask, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.use_loc_attn = use_loc_attn
        self.loc_attn_n_filters = loc_attn_n_filters
        self.loc_attn_kernel_size = loc_attn_kernel_size
        self.use_windowing = use_windowing
        self.norm = norm
        self.use_forward_attn = use_forward_attn
        self.use_trans_agent = use_trans_agent
        self.use_forward_attn_mask = use_forward_attn_mask
        self.query_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name='query_layer/linear_layer')
        self.inputs_layer = tf.keras.layers.Dense(attn_dim, use_bias=False, name=f'{self.name}/inputs_layer/linear_layer')
        self.v = tf.keras.layers.Dense(1, use_bias=True, name='v/linear_layer')
        if use_loc_attn:
            self.location_conv1d = keras.layers.Conv1D(
                filters=loc_attn_n_filters,
                kernel_size=loc_attn_kernel_size,
                padding='same',
                use_bias=False,
                name='location_layer/location_conv1d')
            self.location_dense = keras.layers.Dense(attn_dim, use_bias=False, name='location_layer/location_dense')
        if norm == 'softmax':
            self.norm_func = tf.nn.softmax
        elif norm == 'sigmoid':
            self.norm_func = _sigmoid_norm
        else:
            raise ValueError("Unknown value for attention norm type")

    def init_states(self, batch_size, value_length):
        states = []
        if self.use_loc_attn:
            attention_cum = tf.zeros([batch_size, value_length])
            attention_old = tf.zeros([batch_size, value_length])
            states = [attention_cum, attention_old]
        if self.use_forward_attn:
            alpha = tf.concat(
                [tf.ones([batch_size, 1]),
                 tf.zeros([batch_size, value_length])[:, :-1] + 1e-7], axis=1)
            states.append(alpha)
        return tuple(states)
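
    # Note (added comment): the states tuple is laid out as
    # (attention_cum, attention_old) when location attention is enabled, with the
    # forward-attention distribution `alpha` appended last when `use_forward_attn`
    # is set; `update_states` and `call` below rely on this ordering.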

    def process_values(self, values):
        """ cache values for decoder iterations """
        #pylint: disable=attribute-defined-outside-init
        self.processed_values = self.inputs_layer(values)
        self.values = values

    def get_loc_attn(self, query, states):
        """ compute location attention, query layer and
        unnormalized attention weights """
        attention_cum, attention_old = states[:2]
        attn_cat = tf.stack([attention_old, attention_cum], axis=2)

        processed_query = self.query_layer(tf.expand_dims(query, 1))
        processed_attn = self.location_dense(self.location_conv1d(attn_cat))
        score = self.v(
            tf.nn.tanh(self.processed_values + processed_query +
                       processed_attn))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def get_attn(self, query):
        """ compute query layer and unnormalized attention weights """
        processed_query = self.query_layer(tf.expand_dims(query, 1))
        score = self.v(tf.nn.tanh(self.processed_values + processed_query))
        score = tf.squeeze(score, axis=2)
        return score, processed_query

    def apply_score_masking(self, score, mask):  #pylint: disable=no-self-use
        """ ignore sequence paddings """
        padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
        # Bias so padding positions do not contribute to attention distribution.
        score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
        return score

    def apply_forward_attention(self, alignment, alpha):
        # forward attention
        fwd_shifted_alpha = tf.pad(alpha[:, :-1], ((0, 0), (1, 0)))
        # compute transition potentials
        new_alpha = ((1 - 0.5) * alpha + 0.5 * fwd_shifted_alpha + 1e-8) * alignment
        # renormalize attention weights
        new_alpha = new_alpha / tf.reduce_sum(new_alpha, axis=1, keepdims=True)
        return new_alpha
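
    # Note (added comment): with the transition probability fixed at 0.5, the update
    # above computes alpha_t proportional to
    # (0.5 * alpha_{t-1} + 0.5 * shift(alpha_{t-1}) + 1e-8) * alignment_t,
    # which biases attention to either stay in place or advance one encoder step
    # per decoder step.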

    def update_states(self, old_states, scores_norm, attn_weights, new_alpha=None):
        states = []
        if self.use_loc_attn:
            states = [old_states[0] + scores_norm, attn_weights]
        if self.use_forward_attn:
            states.append(new_alpha)
        return tuple(states)

    def call(self, query, states):
        """
        shapes:
            query: B x D
        """
        if self.use_loc_attn:
            score, _ = self.get_loc_attn(query, states)
        else:
            score, _ = self.get_attn(query)

        # TODO: masking
        # if mask is not None:
        #     self.apply_score_masking(score, mask)
        # attn_weights shape == (batch_size, max_length)

        # normalize attention scores
        scores_norm = self.norm_func(score)
        attn_weights = scores_norm

        # apply forward attention
        new_alpha = None
        if self.use_forward_attn:
            new_alpha = self.apply_forward_attention(attn_weights, states[-1])
            attn_weights = new_alpha

        # update states tuple
        # states = (cum_attn_weights, attn_weights, new_alpha)
        states = self.update_states(states, scores_norm, attn_weights, new_alpha)

        # context_vector shape after sum == (batch_size, hidden_size)
        context_vector = tf.matmul(tf.expand_dims(attn_weights, axis=2), self.values, transpose_a=True, transpose_b=False)
        context_vector = tf.squeeze(context_vector, axis=1)
        return context_vector, attn_weights, states
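

# Usage sketch (illustrative comment): one decoder step with location-sensitive
# attention; the dimensions below are assumptions for the example only.
#   attn = Attention(attn_dim=128, use_loc_attn=True, loc_attn_n_filters=32,
#                    loc_attn_kernel_size=31, use_windowing=False, norm='softmax',
#                    use_forward_attn=False, use_trans_agent=False,
#                    use_forward_attn_mask=False)
#   values = tf.random.normal([8, 50, 512])          # encoder outputs: B x T_en x D_en
#   attn.process_values(values)                      # cache projected values once per utterance
#   states = attn.init_states(batch_size=8, value_length=50)
#   query = tf.random.normal([8, 1024])              # decoder state: B x D_query
#   context, weights, states = attn(query, states)   # context: (8, 512), weights: (8, 50)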


# def _location_sensitive_score(processed_query, keys, processed_loc, attention_v, attention_b):
#     dtype = processed_query.dtype
#     num_units = keys.shape[-1].value or array_ops.shape(keys)[-1]
#     return tf.reduce_sum(attention_v * tf.tanh(keys + processed_query + processed_loc + attention_b), [2])


# class LocationSensitiveAttention(BahdanauAttention):
#     def __init__(self,
#                  units,
#                  memory=None,
#                  memory_sequence_length=None,
#                  normalize=False,
#                  probability_fn="softmax",
#                  kernel_initializer="glorot_uniform",
#                  dtype=None,
#                  name="LocationSensitiveAttention",
#                  location_attention_filters=32,
#                  location_attention_kernel_size=31):

#         super(LocationSensitiveAttention,
#               self).__init__(units=units,
#                              memory=memory,
#                              memory_sequence_length=memory_sequence_length,
#                              normalize=normalize,
#                              probability_fn='softmax',  ## parent module default
#                              kernel_initializer=kernel_initializer,
#                              dtype=dtype,
#                              name=name)
#         if probability_fn == 'sigmoid':
#             self.probability_fn = lambda score, _: self._sigmoid_normalization(score)
#         self.location_conv = keras.layers.Conv1D(filters=location_attention_filters, kernel_size=location_attention_kernel_size, padding='same', use_bias=False)
#         self.location_dense = keras.layers.Dense(units, use_bias=False)
#         # self.v = keras.layers.Dense(1, use_bias=True)

#     def _location_sensitive_score(self, processed_query, keys, processed_loc):
#         processed_query = tf.expand_dims(processed_query, 1)
#         return tf.reduce_sum(self.attention_v * tf.tanh(keys + processed_query + processed_loc), [2])

#     def _location_sensitive(self, alignment_cum, alignment_old):
#         alignment_cat = tf.stack([alignment_cum, alignment_old], axis=2)
#         return self.location_dense(self.location_conv(alignment_cat))

#     def _sigmoid_normalization(self, score):
#         return tf.nn.sigmoid(score) / tf.reduce_sum(tf.nn.sigmoid(score), axis=-1, keepdims=True)

#     # def _apply_masking(self, score, mask):
#     #     padding_mask = tf.expand_dims(math_ops.logical_not(mask), 2)
#     #     # Bias so padding positions do not contribute to attention distribution.
#     #     score -= 1.e9 * math_ops.cast(padding_mask, dtype=tf.float32)
#     #     return score

#     def _calculate_attention(self, query, state):
#         alignment_cum, alignment_old = state[:2]
#         processed_query = self.query_layer(
#             query) if self.query_layer else query
#         processed_loc = self._location_sensitive(alignment_cum, alignment_old)
#         score = self._location_sensitive_score(
#             processed_query,
#             self.keys,
#             processed_loc)
#         alignment = self.probability_fn(score, state)
#         alignment_cum = alignment_cum + alignment
#         state[0] = alignment_cum
#         state[1] = alignment
#         return alignment, state

#     def compute_context(self, alignments):
#         expanded_alignments = tf.expand_dims(alignments, 1)
#         context = tf.matmul(expanded_alignments, self.values)
#         context = tf.squeeze(context, [1])
#         return context

#     # def call(self, query, state):
#     #     alignment, next_state = self._calculate_attention(query, state)
#     #     return alignment, next_state